Switch WrapDatasetVariantTest to use TF combinations
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 2c08535..a2cc54d 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -32,7 +32,6 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import data_flow_ops
@@ -41,7 +40,7 @@
from tensorflow.python.platform import test
-@test_util.run_v1_only("b/123903858: Add eager and V2 test coverage")
+# TODO(b/123903858): Add eager and V2 test coverage
class MapDefunTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
diff --git a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
index 09627d0..b65c0fb 100644
--- a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
@@ -17,18 +17,20 @@
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import combinations
from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.platform import test
-@test_util.run_all_in_graph_and_eager_modes
-class WrapDatasetVariantTest(test_base.DatasetTestBase):
+class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase):
+ @combinations.generate(test_base.default_test_combinations())
def testBasic(self):
ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access
@@ -42,7 +44,9 @@
for i in range(100):
self.assertEqual(i, self.evaluate(get_next()))
- @test_util.run_v1_only("b/123901304")
+ # TODO("b/123901304")
+ @combinations.generate(
+ combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
def testSkipEagerGPU(self):
ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access