Move snapshot_test from data/experimental/kernel_tests to data/kernel_tests
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 31a4706..e610db2 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -590,31 +590,6 @@
)
tf_py_test(
- name = "snapshot_test",
- size = "small",
- timeout = "long",
- srcs = ["snapshot_test.py"],
- shard_count = 16,
- tags = [
- "no_windows", # TODO(b/182379890)
- ],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/experimental/ops:readers",
- "//tensorflow/python/data/experimental/ops:snapshot",
- "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/kernel_tests:tf_record_test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-tf_py_test(
name = "sql_dataset_test",
size = "small",
srcs = ["sql_dataset_test.py"],
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 0205ed5..e6aba3b 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -854,6 +854,31 @@
)
tf_py_test(
+ name = "snapshot_test",
+ size = "small",
+ timeout = "long",
+ srcs = ["snapshot_test.py"],
+ shard_count = 16,
+ tags = [
+ "no_windows", # TODO(b/182379890)
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/experimental/ops:snapshot",
+ "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/kernel_tests:tf_record_test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+tf_py_test(
name = "take_test",
size = "small",
srcs = ["take_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/kernel_tests/snapshot_test.py
similarity index 95%
rename from tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
rename to tensorflow/python/data/kernel_tests/snapshot_test.py
index f720966..71429d4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
+++ b/tensorflow/python/data/kernel_tests/snapshot_test.py
@@ -119,7 +119,7 @@
@combinations.generate(test_base.default_test_combinations())
def testCreateSnapshotDataset(self):
dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
- dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset.snapshot(path=self._snapshot_dir)
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetDefault(self):
@@ -132,7 +132,7 @@
]
dataset = core_readers._TFRecordDataset(filenames)
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
@@ -142,7 +142,7 @@
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
- dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
@@ -156,14 +156,15 @@
]
dataset = core_readers._TFRecordDataset(filenames)
- dataset = dataset.apply(
- snapshot.snapshot(self._snapshot_dir, compression="AUTO"))
+ dataset = dataset.snapshot(path=self._snapshot_dir, compression="AUTO")
self.assertDatasetProduces(dataset, expected)
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
- dataset2 = dataset2.apply(
- snapshot.snapshot(self._snapshot_dir, compression="SNAPPY"))
+ dataset2 = dataset2.snapshot(
+ path=self._snapshot_dir,
+ compression="SNAPPY"
+ )
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
@@ -177,8 +178,10 @@
]
dataset = core_readers._TFRecordDataset(filenames)
- dataset = dataset.apply(
- snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: np.int64(0)))
+ dataset = dataset.snapshot(
+ path=self._snapshot_dir,
+ shard_func=lambda _: np.int64(0)
+ )
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
@@ -188,8 +191,10 @@
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
- dataset2 = dataset2.apply(
- snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: 0))
+ dataset2 = dataset2.snapshot(
+ path=self._snapshot_dir,
+ shard_func=lambda _: 0
+ )
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
@@ -203,14 +208,13 @@
]
dataset = core_readers._TFRecordDataset(filenames)
- dataset = dataset.apply(
- snapshot.snapshot(
- self._snapshot_dir,
+ dataset = dataset.snapshot(
+ path=self._snapshot_dir,
reader_func=(
lambda ds: ds.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=4,
- num_parallel_calls=4))))
+ num_parallel_calls=4)))
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
@@ -220,23 +224,22 @@
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
-    dataset2 = dataset2.apply(
-        snapshot.snapshot(
-            self._snapshot_dir,
+    dataset2 = dataset2.snapshot(
+        path=self._snapshot_dir,
reader_func=(
lambda ds: ds.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=4,
- num_parallel_calls=4))))
+ num_parallel_calls=4)))
self.assertDatasetProducesSet(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testSnapshotDatasetInvalidShardFn(self):
dataset = dataset_ops.Dataset.range(1000)
with self.assertRaises(TypeError):
- dataset = dataset.apply(
- snapshot.snapshot(
- self._snapshot_dir, shard_func=lambda _: "invalid_fn"))
+ dataset = dataset.snapshot(
+ path=self._snapshot_dir,
+ shard_func=lambda _: "invalid_fn")
next_fn = self.getNext(dataset)
self.evaluate(next_fn())
@@ -244,15 +247,17 @@
def testSnapshotDatasetInvalidReaderFn(self):
dataset = dataset_ops.Dataset.range(1000)
with self.assertRaises(TypeError):
- dataset = dataset.apply(
- snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1))
+ dataset = dataset.snapshot(
+ path=self._snapshot_dir,
+ reader_func=lambda x: x + 1
+ )
next_fn = self.getNext(dataset)
self.evaluate(next_fn())
@combinations.generate(test_base.default_test_combinations())
def testRoundtripEmptySnapshot(self):
dataset = dataset_ops.Dataset.range(0)
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, [])
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
@@ -261,13 +266,13 @@
num_snapshot_shards_per_run=0)
dataset2 = dataset_ops.Dataset.range(0)
- dataset2 = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, [])
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSimple(self):
dataset = dataset_ops.Dataset.range(1000)
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
@@ -278,11 +283,11 @@
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetMultipleFingerprints(self):
dataset1 = dataset_ops.Dataset.range(1000)
- dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset1 = dataset1.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset1, list(range(1000)))
dataset2 = dataset_ops.Dataset.range(2000)
- dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, list(range(2000)))
self.assertSnapshotDirectoryContains(
@@ -294,10 +299,10 @@
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
dataset1 = dataset_ops.Dataset.range(1000)
- dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset1 = dataset1.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset1, list(range(1000)))
dataset2 = dataset_ops.Dataset.range(1000)
- dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, list(range(1000)))
self.assertSnapshotDirectoryContains(
@@ -309,13 +314,13 @@
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self):
dataset1 = dataset_ops.Dataset.range(1000)
- dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset1 = dataset1.snapshot(path=self._snapshot_dir)
next1 = self.getNext(dataset1)
for i in range(500):
self.assertEqual(i, self.evaluate(next1()))
dataset2 = dataset_ops.Dataset.range(1000)
- dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset2 = dataset2.snapshot(path=self._snapshot_dir)
next2 = self.getNext(dataset2)
for i in range(500):
self.assertEqual(i, self.evaluate(next2()))
@@ -334,8 +339,10 @@
def testWriteSnapshotCustomShardFunction(self):
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.enumerate()
- dataset = dataset.apply(
- snapshot.snapshot(self._snapshot_dir, shard_func=lambda i, _: i % 2))
+ dataset = dataset.snapshot(
+ path=self._snapshot_dir,
+ shard_func=lambda i, _: i % 2
+ )
dataset = dataset.map(lambda _, elem: elem)
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(
@@ -352,7 +359,7 @@
dataset4 = dataset_ops.Dataset.range(3000, 4000)
dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4))
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
expected = list(
zip(
@@ -371,7 +378,7 @@
def make_dataset():
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.shuffle(1000)
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
return dataset
dataset1 = make_dataset()
@@ -387,7 +394,7 @@
@combinations.generate(test_base.default_test_combinations())
def testReadUsingFlatMap(self):
dataset = dataset_ops.Dataset.range(1000)
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, list(range(1000)))
flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x)
self.assertDatasetProduces(flat_map, list(range(1000)))
@@ -403,7 +410,7 @@
# Will be optimized into ShuffleAndRepeat.
dataset = dataset.shuffle(10)
dataset = dataset.repeat(2)
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProducesSet(dataset, 2 * list(range(1000)))
flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x)
self.assertDatasetProducesSet(flat_map, 2 * list(range(1000)))
@@ -417,7 +424,7 @@
def testRepeatAndPrefetch(self):
"""This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
dataset = dataset.shuffle(buffer_size=16)
dataset = dataset.batch(16)
dataset = dataset.repeat()
@@ -1031,7 +1038,7 @@
os.mkdir(self._snapshot_dir)
dataset = dataset_ops.Dataset.range(100)
- dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+ dataset = dataset.snapshot(path=self._snapshot_dir)
if repeat:
dataset = dataset.repeat(2)
return dataset