[tf.data] Refactoring optimization test methods.

This CL breaks down large tests that iterate over different test cases into smaller ones -- one per test case.

PiperOrigin-RevId: 283993741
Change-Id: I0e67958279d924d0b139164108e971bf39de96ca
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
index 949f9e2..1df52da 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
@@ -17,6 +17,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.ops import testing
@@ -29,12 +31,42 @@
 from tensorflow.python.platform import test
 
 
+def _test_combinations():
+  cases = []
+
+  take_all = lambda x: constant_op.constant(True)
+  is_zero = lambda x: math_ops.equal(x, 0)
+  greater = lambda x: math_ops.greater(x + 5, 0)
+  predicates = [take_all, is_zero, greater]
+  for i, x in enumerate(predicates):
+    for j, y in enumerate(predicates):
+      cases.append((lambda x: x, "Scalar{}{}".format(i, j), [x, y]))
+      for k, z in enumerate(predicates):
+        cases.append((lambda x: x, "Scalar{}{}{}".format(i, j, k), [x, y, z]))
+
+  take_all = lambda x, y: constant_op.constant(True)
+  is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+
+  cases.append((lambda x: (x, x), "Tuple1", [take_all, take_all]))
+  cases.append((lambda x: (x, 2), "Tuple2", [take_all, is_zero]))
+
+  def reduce_fn(x, y):
+    function, name, predicates = y
+    return x + combinations.combine(
+        function=function,
+        predicates=combinations.NamedObject(name, predicates))
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
 class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def _testFilterFusion(self, map_function, predicates):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _test_combinations()))
+  def testFilterFusion(self, function, predicates):
     dataset = dataset_ops.Dataset.range(5).apply(
-        testing.assert_next(["Map", "Filter",
-                             "MemoryCacheImpl"])).map(map_function)
+        testing.assert_next(["Map", "Filter", "MemoryCacheImpl"])).map(function)
     for predicate in predicates:
       dataset = dataset.filter(predicate)
 
@@ -45,7 +77,7 @@
     dataset = dataset.with_options(options)
     expected_output = []
     for x in range(5):
-      r = map_function(x)
+      r = function(x)
       filtered = False
       for predicate in predicates:
         if isinstance(r, tuple):
@@ -60,26 +92,6 @@
         expected_output.append(r)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testFilterFusionScalar(self):
-    take_all = lambda x: constant_op.constant(True)
-    is_zero = lambda x: math_ops.equal(x, 0)
-    greater = lambda x: math_ops.greater(x + 5, 0)
-    predicates = [take_all, is_zero, greater]
-    for x in predicates:
-      for y in predicates:
-        self._testFilterFusion(lambda x: x, [x, y])
-        for z in predicates:
-          self._testFilterFusion(lambda x: x, [x, y, z])
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testFilterFusionTuple(self):
-    take_all = lambda x, y: constant_op.constant(True)
-    is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
-
-    self._testFilterFusion(lambda x: (x, x), [take_all, take_all])
-    self._testFilterFusion(lambda x: (x, 2), [take_all, is_zero])
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
index 59f50fa..1097b1e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -17,6 +17,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.ops import testing
@@ -33,6 +35,36 @@
 from tensorflow.python.platform import test
 
 
+def _test_combinations():
+  def random(_):
+    return random_ops.random_uniform([],
+                                     minval=1,
+                                     maxval=10,
+                                     dtype=dtypes.float32,
+                                     seed=42)
+
+  def random_with_assert(x):
+    y = random(x)
+    assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
+    with ops.control_dependencies([assert_op]):
+      return y
+
+  cases = [
+      ("Increment", lambda x: x + 1, False),
+      ("Random", random, True),
+      ("RandomWithAssert", random_with_assert, True),
+      ("Complex", lambda x: (random(x) + random(x)) / 2, False),
+  ]
+
+  def reduce_fn(x, y):
+    name, map_fn, should_optimize = y
+    return x + combinations.combine(
+        map_fn=combinations.NamedObject(name, map_fn),
+        should_optimize=should_optimize)
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
 class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def _testDataset(self, dataset):
@@ -51,10 +83,13 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(get_next())
 
-  def _testHoistFunction(self, function, should_optimize):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _test_combinations()))
+  def testHoistFunction(self, map_fn, should_optimize):
     dataset = dataset_ops.Dataset.range(5).apply(
         testing.assert_next(
-            ["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function)
+            ["Zip[0]", "Map"] if should_optimize else ["Map"])).map(map_fn)
 
     options = dataset_ops.Options()
     options.experimental_optimization.apply_default_optimizations = False
@@ -63,31 +98,6 @@
     self._testDataset(dataset)
 
   @combinations.generate(test_base.default_test_combinations())
-  def testNoRandom(self):
-    self._testHoistFunction(lambda x: x + 1, should_optimize=False)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testRandom(self):
-
-    def random(_):
-      return random_ops.random_uniform([],
-                                       minval=1,
-                                       maxval=10,
-                                       dtype=dtypes.float32,
-                                       seed=42)
-
-    def random_with_assert(x):
-      y = random(x)
-      assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
-      with ops.control_dependencies([assert_op]):
-        return y
-
-    self._testHoistFunction(random, should_optimize=True)
-    self._testHoistFunction(random_with_assert, should_optimize=True)
-    self._testHoistFunction(
-        lambda x: (random(x) + random(x)) / 2, should_optimize=False)
-
-  @combinations.generate(test_base.default_test_combinations())
   def testCapturedInputs(self):
     a = constant_op.constant(1, dtype=dtypes.float32)
     b = constant_op.constant(0, dtype=dtypes.float32)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
index a0257f7..aa0ab40 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -17,6 +17,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.ops import testing
@@ -29,6 +31,49 @@
 from tensorflow.python.platform import test
 
 
+def _test_combinations():
+  cases = []
+
+  identity = lambda x: x
+  increment = lambda x: x + 1
+  minus_five = lambda x: x - 5
+
+  def increment_and_square(x):
+    y = x + 1
+    return y * y
+
+  functions = [identity, increment, minus_five, increment_and_square]
+
+  take_all = lambda x: constant_op.constant(True)
+  is_zero = lambda x: math_ops.equal(x, 0)
+  is_odd = lambda x: math_ops.equal(x % 2, 0)
+  greater = lambda x: math_ops.greater(x + 5, 0)
+  predicates = [take_all, is_zero, is_odd, greater]
+
+  for i, function in enumerate(functions):
+    for j, predicate in enumerate(predicates):
+      cases.append((function, "Scalar{}{}".format(i, j), predicate))
+
+  replicate = lambda x: (x, x)
+  with_two = lambda x: (x, 2)
+  functions = [replicate, with_two]
+  take_all = lambda x, y: constant_op.constant(True)
+  is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+  predicates = [take_all, is_zero]
+
+  for i, function in enumerate(functions):
+    for j, predicate in enumerate(predicates):
+      cases.append((function, "Tuple{}{}".format(i, j), predicate))
+
+  def reduce_fn(x, y):
+    function, name, predicate = y
+    return x + combinations.combine(
+        function=function,
+        predicate=combinations.NamedObject(name, predicate))
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
 class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def _testDataset(self, dataset, function, predicate):
@@ -43,7 +88,10 @@
         expected_output.append(r)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
-  def _testMapAndFilterFusion(self, function, predicate):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _test_combinations()))
+  def testMapAndFilterFusion(self, function, predicate):
     dataset = dataset_ops.Dataset.range(10).apply(
         testing.assert_next(["Map", "Filter",
                              "Map"])).map(function).filter(predicate)
@@ -54,41 +102,6 @@
     self._testDataset(dataset, function, predicate)
 
   @combinations.generate(test_base.default_test_combinations())
-  def testMapAndFilterFusionScalar(self):
-    identity = lambda x: x
-    increment = lambda x: x + 1
-    minus_five = lambda x: x - 5
-
-    def increment_and_square(x):
-      y = x + 1
-      return y * y
-
-    functions = [identity, increment, minus_five, increment_and_square]
-
-    take_all = lambda x: constant_op.constant(True)
-    is_zero = lambda x: math_ops.equal(x, 0)
-    is_odd = lambda x: math_ops.equal(x % 2, 0)
-    greater = lambda x: math_ops.greater(x + 5, 0)
-    predicates = [take_all, is_zero, is_odd, greater]
-
-    for function in functions:
-      for predicate in predicates:
-        self._testMapAndFilterFusion(function, predicate)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testMapAndFilterFusionTuple(self):
-    replicate = lambda x: (x, x)
-    with_two = lambda x: (x, 2)
-    functions = [replicate, with_two]
-    take_all = lambda x, y: constant_op.constant(True)
-    is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
-    predicates = [take_all, is_zero]
-
-    for function in functions:
-      for predicate in predicates:
-        self._testMapAndFilterFusion(function, predicate)
-
-  @combinations.generate(test_base.default_test_combinations())
   def testCapturedInputs(self):
     a = constant_op.constant(3, dtype=dtypes.int64)
     b = constant_op.constant(4, dtype=dtypes.int64)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
index 28da047..efe9c48 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
@@ -17,6 +17,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.ops import testing
@@ -26,9 +28,44 @@
 from tensorflow.python.platform import test
 
 
+def _test_combinations():
+  cases = []
+
+  identity = lambda x: x
+  increment = lambda x: x + 1
+
+  def increment_and_square(x):
+    y = x + 1
+    return y * y
+
+  functions = [identity, increment, increment_and_square]
+
+  for i, x in enumerate(functions):
+    for j, y in enumerate(functions):
+      cases.append(("Scalar{}{}".format(i, j), [x, y]))
+      for k, z in enumerate(functions):
+        cases.append(("Scalar{}{}{}".format(i, j, k), [x, y, z]))
+
+  with_42 = lambda x: (x, 42)
+  swap = lambda x, y: (y, x)
+
+  cases.append(("Tuple1", [with_42, swap]))
+  cases.append(("Tuple2", [with_42, swap, swap]))
+
+  def reduce_fn(x, y):
+    name, functions = y
+    return x + combinations.combine(
+        functions=combinations.NamedObject(name, functions))
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
 class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def _testMapFusion(self, functions):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _test_combinations()))
+  def testMapFusion(self, functions):
     dataset = dataset_ops.Dataset.range(5).apply(
         testing.assert_next(["Map", "MemoryCacheImpl"]))
     for function in functions:
@@ -50,31 +87,6 @@
       expected_output.append(r)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testMapFusionScalar(self):
-    identity = lambda x: x
-    increment = lambda x: x + 1
-
-    def increment_and_square(x):
-      y = x + 1
-      return y * y
-
-    functions = [identity, increment, increment_and_square]
-
-    for x in functions:
-      for y in functions:
-        self._testMapFusion([x, y])
-        for z in functions:
-          self._testMapFusion([x, y, z])
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testMapAndFilterFusionTuple(self):
-    with_42 = lambda x: (x, 42)
-    swap = lambda x, y: (y, x)
-
-    self._testMapFusion([with_42, swap])
-    self._testMapFusion([with_42, swap, swap])
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
index a28a305..ac92dde 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
@@ -17,6 +17,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.ops import testing
@@ -32,9 +34,33 @@
 from tensorflow.python.platform import test
 
 
+def _test_combinations():
+  def assert_greater(x):
+    assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+    with ops.control_dependencies([assert_op]):
+      return x
+
+  cases = [
+      ("Identity", lambda x: x, True),
+      ("Increment", lambda x: x + 1, True),
+      ("AssertGreater", assert_greater, True),
+  ]
+
+  def reduce_fn(x, y):
+    name, function, should_optimize = y
+    return x + combinations.combine(
+        function=combinations.NamedObject(name, function),
+        should_optimize=should_optimize)
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
 class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 
-  def _testMapParallelization(self, function, should_optimize):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _test_combinations()))
+  def testMapParallelization(self, function, should_optimize):
     next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
     dataset = dataset_ops.Dataset.range(5).apply(
         testing.assert_next(next_nodes)).map(function)
@@ -46,24 +72,6 @@
         dataset, expected_output=[function(x) for x in range(5)])
 
   @combinations.generate(test_base.default_test_combinations())
-  def testIdentity(self):
-    self._testMapParallelization(lambda x: x, should_optimize=True)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testIncrement(self):
-    self._testMapParallelization(lambda x: x + 1, should_optimize=True)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testAssert(self):
-
-    def assert_greater(x):
-      assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
-      with ops.control_dependencies([assert_op]):
-        return x
-
-    self._testMapParallelization(assert_greater, should_optimize=True)
-
-  @combinations.generate(test_base.default_test_combinations())
   def testCapturedConstant(self):
     captured_t = constant_op.constant(42, dtype=dtypes.int64)
     def fn(x):
diff --git a/tensorflow/python/framework/test_combinations.py b/tensorflow/python/framework/test_combinations.py
index 95a3dc4..0986585 100644
--- a/tensorflow/python/framework/test_combinations.py
+++ b/tensorflow/python/framework/test_combinations.py
@@ -400,6 +400,9 @@
   def __call__(self, *args, **kwargs):
     return self._obj(*args, **kwargs)
 
+  def __iter__(self):
+    return self._obj.__iter__()
+
   def __repr__(self):
     return self._name