[tf.data] Refactoring optimization test methods.
This CL breaks down large test methods that loop over multiple test cases into parameterized tests, one test per case, built with the test combinations framework. It also adds __iter__ to NamedObject so that a named list of functions (e.g. a list of filter predicates) can be iterated directly inside a parameterized test.
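
The resulting pattern, roughly (a simplified sketch based on the changes below; ExampleTest and the case names are illustrative and not part of this CL):

import functools

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


def _test_combinations():
  # Each case is (name, function). reduce_fn folds the cases into a single
  # list of parameter combinations consumed by @combinations.generate.
  cases = [
      ("Identity", lambda x: x),
      ("Increment", lambda x: x + 1),
  ]

  def reduce_fn(x, y):
    name, function = y
    return x + combinations.combine(
        function=combinations.NamedObject(name, function))

  return functools.reduce(reduce_fn, cases, [])


class ExampleTest(test_base.DatasetTestBase, parameterized.TestCase):

  # One parameterized test method replaces a loop over cases; the framework
  # generates one test per (default combination, case) pair.
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_combinations()))
  def testExample(self, function):
    dataset = dataset_ops.Dataset.range(5).map(function)
    self.assertDatasetProduces(
        dataset, expected_output=[function(x) for x in range(5)])


if __name__ == "__main__":
  test.main()
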
PiperOrigin-RevId: 283993741
Change-Id: I0e67958279d924d0b139164108e971bf39de96ca
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
index 949f9e2..1df52da 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py
@@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function
+import functools
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@@ -29,12 +31,42 @@
from tensorflow.python.platform import test
+def _test_combinations():
+ cases = []
+
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+ predicates = [take_all, is_zero, greater]
+ for i, x in enumerate(predicates):
+ for j, y in enumerate(predicates):
+ cases.append((lambda x: x, "Scalar{}{}".format(i, j), [x, y]))
+ for k, z in enumerate(predicates):
+ cases.append((lambda x: x, "Scalar{}{}{}".format(i, j, k), [x, y, z]))
+
+ take_all = lambda x, y: constant_op.constant(True)
+ is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+
+ cases.append((lambda x: (x, x), "Tuple1", [take_all, take_all]))
+ cases.append((lambda x: (x, 2), "Tuple2", [take_all, is_zero]))
+
+ def reduce_fn(x, y):
+ function, name, predicates = y
+ return x + combinations.combine(
+ function=function,
+ predicates=combinations.NamedObject(name, predicates))
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
- def _testFilterFusion(self, map_function, predicates):
+ @combinations.generate(
+ combinations.times(test_base.default_test_combinations(),
+ _test_combinations()))
+ def testFilterFusion(self, function, predicates):
dataset = dataset_ops.Dataset.range(5).apply(
- testing.assert_next(["Map", "Filter",
- "MemoryCacheImpl"])).map(map_function)
+ testing.assert_next(["Map", "Filter", "MemoryCacheImpl"])).map(function)
for predicate in predicates:
dataset = dataset.filter(predicate)
@@ -45,7 +77,7 @@
dataset = dataset.with_options(options)
expected_output = []
for x in range(5):
- r = map_function(x)
+ r = function(x)
filtered = False
for predicate in predicates:
if isinstance(r, tuple):
@@ -60,26 +92,6 @@
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
- @combinations.generate(test_base.default_test_combinations())
- def testFilterFusionScalar(self):
- take_all = lambda x: constant_op.constant(True)
- is_zero = lambda x: math_ops.equal(x, 0)
- greater = lambda x: math_ops.greater(x + 5, 0)
- predicates = [take_all, is_zero, greater]
- for x in predicates:
- for y in predicates:
- self._testFilterFusion(lambda x: x, [x, y])
- for z in predicates:
- self._testFilterFusion(lambda x: x, [x, y, z])
-
- @combinations.generate(test_base.default_test_combinations())
- def testFilterFusionTuple(self):
- take_all = lambda x, y: constant_op.constant(True)
- is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
-
- self._testFilterFusion(lambda x: (x, x), [take_all, take_all])
- self._testFilterFusion(lambda x: (x, 2), [take_all, is_zero])
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
index 59f50fa..1097b1e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function
+import functools
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@@ -33,6 +35,36 @@
from tensorflow.python.platform import test
+def _test_combinations():
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=1,
+ maxval=10,
+ dtype=dtypes.float32,
+ seed=42)
+
+ def random_with_assert(x):
+ y = random(x)
+ assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
+ with ops.control_dependencies([assert_op]):
+ return y
+
+ cases = [
+ ("Increment", lambda x: x + 1, False),
+ ("Random", random, True),
+ ("RandomWithAssert", random_with_assert, True),
+ ("Complex", lambda x: (random(x) + random(x)) / 2, False),
+ ]
+
+ def reduce_fn(x, y):
+ name, map_fn, should_optimize = y
+ return x + combinations.combine(
+ map_fn=combinations.NamedObject(name, map_fn),
+ should_optimize=should_optimize)
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testDataset(self, dataset):
@@ -51,10 +83,13 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
- def _testHoistFunction(self, function, should_optimize):
+ @combinations.generate(
+ combinations.times(test_base.default_test_combinations(),
+ _test_combinations()))
+ def testHoistFunction(self, map_fn, should_optimize):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(
- ["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function)
+ ["Zip[0]", "Map"] if should_optimize else ["Map"])).map(map_fn)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
@@ -63,31 +98,6 @@
self._testDataset(dataset)
@combinations.generate(test_base.default_test_combinations())
- def testNoRandom(self):
- self._testHoistFunction(lambda x: x + 1, should_optimize=False)
-
- @combinations.generate(test_base.default_test_combinations())
- def testRandom(self):
-
- def random(_):
- return random_ops.random_uniform([],
- minval=1,
- maxval=10,
- dtype=dtypes.float32,
- seed=42)
-
- def random_with_assert(x):
- y = random(x)
- assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
- with ops.control_dependencies([assert_op]):
- return y
-
- self._testHoistFunction(random, should_optimize=True)
- self._testHoistFunction(random_with_assert, should_optimize=True)
- self._testHoistFunction(
- lambda x: (random(x) + random(x)) / 2, should_optimize=False)
-
- @combinations.generate(test_base.default_test_combinations())
def testCapturedInputs(self):
a = constant_op.constant(1, dtype=dtypes.float32)
b = constant_op.constant(0, dtype=dtypes.float32)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
index a0257f7..aa0ab40 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function
+import functools
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@@ -29,6 +31,49 @@
from tensorflow.python.platform import test
+def _test_combinations():
+ cases = []
+
+ identity = lambda x: x
+ increment = lambda x: x + 1
+ minus_five = lambda x: x - 5
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ functions = [identity, increment, minus_five, increment_and_square]
+
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ is_odd = lambda x: math_ops.equal(x % 2, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+ predicates = [take_all, is_zero, is_odd, greater]
+
+ for i, function in enumerate(functions):
+ for j, predicate in enumerate(predicates):
+ cases.append((function, "Scalar{}{}".format(i, j), predicate))
+
+ replicate = lambda x: (x, x)
+ with_two = lambda x: (x, 2)
+ functions = [replicate, with_two]
+ take_all = lambda x, y: constant_op.constant(True)
+ is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+ predicates = [take_all, is_zero]
+
+ for i, function in enumerate(functions):
+ for j, predicate in enumerate(predicates):
+ cases.append((function, "Tuple{}{}".format(i, j), predicate))
+
+ def reduce_fn(x, y):
+ function, name, predicate = y
+ return x + combinations.combine(
+ function=function,
+ predicate=combinations.NamedObject(name, predicate))
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
def _testDataset(self, dataset, function, predicate):
@@ -43,7 +88,10 @@
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
- def _testMapAndFilterFusion(self, function, predicate):
+ @combinations.generate(
+ combinations.times(test_base.default_test_combinations(),
+ _test_combinations()))
+ def testMapAndFilterFusion(self, function, predicate):
dataset = dataset_ops.Dataset.range(10).apply(
testing.assert_next(["Map", "Filter",
"Map"])).map(function).filter(predicate)
@@ -54,41 +102,6 @@
self._testDataset(dataset, function, predicate)
@combinations.generate(test_base.default_test_combinations())
- def testMapAndFilterFusionScalar(self):
- identity = lambda x: x
- increment = lambda x: x + 1
- minus_five = lambda x: x - 5
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- functions = [identity, increment, minus_five, increment_and_square]
-
- take_all = lambda x: constant_op.constant(True)
- is_zero = lambda x: math_ops.equal(x, 0)
- is_odd = lambda x: math_ops.equal(x % 2, 0)
- greater = lambda x: math_ops.greater(x + 5, 0)
- predicates = [take_all, is_zero, is_odd, greater]
-
- for function in functions:
- for predicate in predicates:
- self._testMapAndFilterFusion(function, predicate)
-
- @combinations.generate(test_base.default_test_combinations())
- def testMapAndFilterFusionTuple(self):
- replicate = lambda x: (x, x)
- with_two = lambda x: (x, 2)
- functions = [replicate, with_two]
- take_all = lambda x, y: constant_op.constant(True)
- is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
- predicates = [take_all, is_zero]
-
- for function in functions:
- for predicate in predicates:
- self._testMapAndFilterFusion(function, predicate)
-
- @combinations.generate(test_base.default_test_combinations())
def testCapturedInputs(self):
a = constant_op.constant(3, dtype=dtypes.int64)
b = constant_op.constant(4, dtype=dtypes.int64)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
index 28da047..efe9c48 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
@@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function
+import functools
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@@ -26,9 +28,44 @@
from tensorflow.python.platform import test
+def _test_combinations():
+ cases = []
+
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ functions = [identity, increment, increment_and_square]
+
+ for i, x in enumerate(functions):
+ for j, y in enumerate(functions):
+ cases.append(("Scalar{}{}".format(i, j), [x, y]))
+ for k, z in enumerate(functions):
+ cases.append(("Scalar{}{}{}".format(i, j, k), [x, y, z]))
+
+ with_42 = lambda x: (x, 42)
+ swap = lambda x, y: (y, x)
+
+ cases.append(("Tuple1", [with_42, swap]))
+ cases.append(("Tuple2", [with_42, swap, swap]))
+
+ def reduce_fn(x, y):
+ name, functions = y
+ return x + combinations.combine(
+ functions=combinations.NamedObject(name, functions))
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
- def _testMapFusion(self, functions):
+ @combinations.generate(
+ combinations.times(test_base.default_test_combinations(),
+ _test_combinations()))
+ def testMapFusion(self, functions):
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(["Map", "MemoryCacheImpl"]))
for function in functions:
@@ -50,31 +87,6 @@
expected_output.append(r)
self.assertDatasetProduces(dataset, expected_output=expected_output)
- @combinations.generate(test_base.default_test_combinations())
- def testMapFusionScalar(self):
- identity = lambda x: x
- increment = lambda x: x + 1
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- functions = [identity, increment, increment_and_square]
-
- for x in functions:
- for y in functions:
- self._testMapFusion([x, y])
- for z in functions:
- self._testMapFusion([x, y, z])
-
- @combinations.generate(test_base.default_test_combinations())
- def testMapAndFilterFusionTuple(self):
- with_42 = lambda x: (x, 42)
- swap = lambda x, y: (y, x)
-
- self._testMapFusion([with_42, swap])
- self._testMapFusion([with_42, swap, swap])
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
index a28a305..ac92dde 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
@@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function
+import functools
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
@@ -32,9 +34,33 @@
from tensorflow.python.platform import test
+def _test_combinations():
+ def assert_greater(x):
+ assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+ with ops.control_dependencies([assert_op]):
+ return x
+
+ cases = [
+ ("Identity", lambda x: x, True),
+ ("Increment", lambda x: x + 1, True),
+ ("AssertGreater", assert_greater, True),
+ ]
+
+ def reduce_fn(x, y):
+ name, function, should_optimize = y
+ return x + combinations.combine(
+ function=combinations.NamedObject(name, function),
+ should_optimize=should_optimize)
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
- def _testMapParallelization(self, function, should_optimize):
+ @combinations.generate(
+ combinations.times(test_base.default_test_combinations(),
+ _test_combinations()))
+ def testMapParallelization(self, function, should_optimize):
next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
dataset = dataset_ops.Dataset.range(5).apply(
testing.assert_next(next_nodes)).map(function)
@@ -46,24 +72,6 @@
dataset, expected_output=[function(x) for x in range(5)])
@combinations.generate(test_base.default_test_combinations())
- def testIdentity(self):
- self._testMapParallelization(lambda x: x, should_optimize=True)
-
- @combinations.generate(test_base.default_test_combinations())
- def testIncrement(self):
- self._testMapParallelization(lambda x: x + 1, should_optimize=True)
-
- @combinations.generate(test_base.default_test_combinations())
- def testAssert(self):
-
- def assert_greater(x):
- assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
- with ops.control_dependencies([assert_op]):
- return x
-
- self._testMapParallelization(assert_greater, should_optimize=True)
-
- @combinations.generate(test_base.default_test_combinations())
def testCapturedConstant(self):
captured_t = constant_op.constant(42, dtype=dtypes.int64)
def fn(x):
diff --git a/tensorflow/python/framework/test_combinations.py b/tensorflow/python/framework/test_combinations.py
index 95a3dc4..0986585 100644
--- a/tensorflow/python/framework/test_combinations.py
+++ b/tensorflow/python/framework/test_combinations.py
@@ -400,6 +400,9 @@
def __call__(self, *args, **kwargs):
return self._obj(*args, **kwargs)
+ def __iter__(self):
+ return self._obj.__iter__()
+
def __repr__(self):
return self._name