move fastest dataset serialization tests to kernel tests
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 7496269..14439f8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -501,8 +501,10 @@
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data/experimental/ops:batching",
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
index 22cea18..d15fdbd 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
@@ -35,10 +35,12 @@
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@@ -674,6 +676,100 @@
self.run_core_tests(build_dataset, 20)
+class ChooseFastestBranchDatasetCheckpointTest(
+ checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testCore(self):
+
+ def build_ds(size):
+ dataset = dataset_ops.Dataset.range(size)
+
+ def branch_0(dataset):
+ return dataset.map(lambda x: x).batch(10)
+
+ def branch_1(dataset):
+ return dataset.batch(10).map(lambda x: x)
+
+ return optimization._ChooseFastestBranchDataset( # pylint: disable=protected-access
+ dataset, [branch_0, branch_1],
+ ratio_numerator=10)
+
+ for size in [100, 1000]:
+ self.run_core_tests(lambda: build_ds(size), size // 10) # pylint: disable=cell-var-from-loop
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testWithCapture(self):
+
+ def build_ds():
+ dataset = dataset_ops.Dataset.range(10)
+ const_64 = constant_op.constant(1, dtypes.int64)
+ const_32 = constant_op.constant(1, dtypes.int32)
+
+ def branch_0(dataset):
+ return dataset.map(lambda x: x + const_64)
+
+ def branch_1(dataset):
+ return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))
+
+ return optimization._ChooseFastestBranchDataset(
+ dataset, [branch_0, branch_1], num_elements_per_branch=3)
+
+ self.run_core_tests(build_ds, 10)
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testWithPrefetch(self):
+
+ def build_ds():
+ dataset = dataset_ops.Dataset.range(10)
+ const_64 = constant_op.constant(1, dtypes.int64)
+ const_32 = constant_op.constant(1, dtypes.int32)
+
+ def branch_0(dataset):
+ return dataset.map(lambda x: x + const_64)
+
+ def branch_1(dataset):
+ return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))
+
+ return optimization._ChooseFastestBranchDataset(
+ dataset, [branch_0, branch_1], num_elements_per_branch=3)
+
+ self.run_core_tests(build_ds, 10)
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testWithMoreOutputThanInput(self):
+
+ def build_ds():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
+
+ def branch(dataset):
+ return dataset.unbatch()
+
+ return optimization._ChooseFastestBranchDataset(
+ dataset, [branch, branch],
+ ratio_denominator=10,
+ num_elements_per_branch=100)
+
+ self.run_core_tests(build_ds, 1000)
+
+class ChooseFastestDatasetCheckpointTest(
+ checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
+
+ @combinations.generate(test_base.default_test_combinations())
+ def testCore(self):
+ num_outputs = 10
+ batch_size = 2
+
+ def build_ds():
+ dataset = dataset_ops.Dataset.range(num_outputs)
+ map_fn = lambda x: x * 2
+ return optimization._ChooseFastestDataset([ # pylint: disable=protected-access
+ dataset.map(map_fn).batch(batch_size),
+ dataset.batch(batch_size).map(map_fn)
+ ])
+
+ self.run_core_tests(build_ds, num_outputs // 2)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index 8284da1..ef821b7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -31,43 +31,6 @@
)
tf_py_test(
- name = "choose_fastest_branch_dataset_serialization_test",
- size = "medium",
- srcs = ["choose_fastest_branch_dataset_serialization_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/experimental/ops:batching",
- "//tensorflow/python/data/experimental/ops:optimization",
- "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "choose_fastest_dataset_serialization_test",
- size = "medium",
- srcs = ["choose_fastest_dataset_serialization_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/experimental/ops:optimization",
- "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
name = "sample_from_datasets_serialization_test",
size = "medium",
srcs = ["sample_from_datasets_serialization_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py
deleted file mode 100644
index 7da6546..0000000
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for checkpointing the ChooseFastestBranchDataset."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.python.data.experimental.ops import optimization
-from tensorflow.python.data.kernel_tests import checkpoint_test_base
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import combinations
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class ChooseFastestBranchDatasetCheckpointTest(
- checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
-
- @combinations.generate(test_base.default_test_combinations())
- def testCore(self):
-
- def build_ds(size):
- dataset = dataset_ops.Dataset.range(size)
-
- def branch_0(dataset):
- return dataset.map(lambda x: x).batch(10)
-
- def branch_1(dataset):
- return dataset.batch(10).map(lambda x: x)
-
- return optimization._ChooseFastestBranchDataset( # pylint: disable=protected-access
- dataset, [branch_0, branch_1],
- ratio_numerator=10)
-
- for size in [100, 1000]:
- self.run_core_tests(lambda: build_ds(size), size // 10) # pylint: disable=cell-var-from-loop
-
- @combinations.generate(test_base.default_test_combinations())
- def testWithCapture(self):
-
- def build_ds():
- dataset = dataset_ops.Dataset.range(10)
- const_64 = constant_op.constant(1, dtypes.int64)
- const_32 = constant_op.constant(1, dtypes.int32)
-
- def branch_0(dataset):
- return dataset.map(lambda x: x + const_64)
-
- def branch_1(dataset):
- return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))
-
- return optimization._ChooseFastestBranchDataset(
- dataset, [branch_0, branch_1], num_elements_per_branch=3)
-
- self.run_core_tests(build_ds, 10)
-
- @combinations.generate(test_base.default_test_combinations())
- def testWithPrefetch(self):
-
- def build_ds():
- dataset = dataset_ops.Dataset.range(10)
- const_64 = constant_op.constant(1, dtypes.int64)
- const_32 = constant_op.constant(1, dtypes.int32)
-
- def branch_0(dataset):
- return dataset.map(lambda x: x + const_64)
-
- def branch_1(dataset):
- return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))
-
- return optimization._ChooseFastestBranchDataset(
- dataset, [branch_0, branch_1], num_elements_per_branch=3)
-
- self.run_core_tests(build_ds, 10)
-
- @combinations.generate(test_base.default_test_combinations())
- def testWithMoreOutputThanInput(self):
-
- def build_ds():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
-
- def branch(dataset):
- return dataset.unbatch()
-
- return optimization._ChooseFastestBranchDataset(
- dataset, [branch, branch],
- ratio_denominator=10,
- num_elements_per_branch=100)
-
- self.run_core_tests(build_ds, 1000)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py
deleted file mode 100644
index 0c4de60..0000000
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for checkpointing the ChooseFastestDataset."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.python.data.experimental.ops import optimization
-from tensorflow.python.data.kernel_tests import checkpoint_test_base
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import combinations
-from tensorflow.python.platform import test
-
-
-class ChooseFastestDatasetCheckpointTest(
- checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
-
- @combinations.generate(test_base.default_test_combinations())
- def testCore(self):
- num_outputs = 10
- batch_size = 2
-
- def build_ds():
- dataset = dataset_ops.Dataset.range(num_outputs)
- map_fn = lambda x: x * 2
- return optimization._ChooseFastestDataset([ # pylint: disable=protected-access
- dataset.map(map_fn).batch(batch_size),
- dataset.batch(batch_size).map(map_fn)
- ])
-
- self.run_core_tests(build_ds, num_outputs // 2)
-
-
-if __name__ == "__main__":
- test.main()