restructured test combinations
diff --git a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py
index 61b089b..912fa31 100644
--- a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py
@@ -60,7 +60,7 @@
]
-def _test_eager_objects():
+def _test_eager_only_objects():
return [
combinations.NamedObject(
"ragged",
@@ -80,7 +80,9 @@
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
- combinations.combine(element=_test_objects())))
+ combinations.combine(element=_test_objects())) +
+ combinations.times(test_base.eager_only_combinations(),
+ combinations.combine(element=_test_eager_only_objects())))
def testCompression(self, element):
element = element._obj
@@ -91,7 +93,9 @@
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
- combinations.combine(element=_test_objects())))
+ combinations.combine(element=_test_objects())) +
+ combinations.times(test_base.eager_only_combinations(),
+ combinations.combine(element=_test_eager_only_objects())))
def testDatasetCompression(self, element):
element = element._obj
@@ -102,30 +106,6 @@
dataset = dataset.map(lambda x: compression_ops.uncompress(x, element_spec))
self.assertDatasetProduces(dataset, [element])
- @combinations.generate(combinations.times(
- test_base.eager_only_combinations(),
- combinations.combine(element=_test_objects() + _test_eager_objects())))
- def testCompressionEager(self, element):
- element = element._obj
-
- compressed = compression_ops.compress(element)
- uncompressed = compression_ops.uncompress(
- compressed, structure.type_spec_from_value(element))
- self.assertValuesEqual(element, self.evaluate(uncompressed))
-
- @combinations.generate(combinations.times(
- test_base.eager_only_combinations(),
- combinations.combine(element=_test_objects() + _test_eager_objects())))
- def testDatasetCompressionEager(self, element):
- element = element._obj
-
- dataset = dataset_ops.Dataset.from_tensors(element)
- element_spec = dataset.element_spec
-
- dataset = dataset.map(lambda *x: compression_ops.compress(x))
- dataset = dataset.map(lambda x: compression_ops.uncompress(x, element_spec))
- self.assertDatasetProduces(dataset, [element])
-
if __name__ == "__main__":
test.main()