move snapshot_test to data/kernel_tests
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 31a4706..e610db2 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -590,31 +590,6 @@
 )
 
 tf_py_test(
-    name = "snapshot_test",
-    size = "small",
-    timeout = "long",
-    srcs = ["snapshot_test.py"],
-    shard_count = 16,
-    tags = [
-        "no_windows",  # TODO(b/182379890)
-    ],
-    deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:string_ops",
-        "//tensorflow/python/data/experimental/ops:readers",
-        "//tensorflow/python/data/experimental/ops:snapshot",
-        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
-        "//tensorflow/python/data/kernel_tests:test_base",
-        "//tensorflow/python/data/kernel_tests:tf_record_test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/ops:readers",
-        "//tensorflow/python/data/util:nest",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
-tf_py_test(
     name = "sql_dataset_test",
     size = "small",
     srcs = ["sql_dataset_test.py"],
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 0205ed5..e6aba3b 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -854,6 +854,31 @@
 )
 
 tf_py_test(
+    name = "snapshot_test",
+    size = "small",
+    timeout = "long",
+    srcs = ["snapshot_test.py"],
+    shard_count = 16,
+    tags = [
+        "no_windows",  # TODO(b/182379890)
+    ],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python/data/experimental/ops:readers",
+        "//tensorflow/python/data/experimental/ops:snapshot",
+        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/kernel_tests:tf_record_test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:readers",
+        "//tensorflow/python/data/util:nest",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+tf_py_test(
     name = "take_test",
     size = "small",
     srcs = ["take_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/kernel_tests/snapshot_test.py
similarity index 95%
rename from tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
rename to tensorflow/python/data/kernel_tests/snapshot_test.py
index f720966..71429d4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
+++ b/tensorflow/python/data/kernel_tests/snapshot_test.py
@@ -119,7 +119,7 @@
   @combinations.generate(test_base.default_test_combinations())
   def testCreateSnapshotDataset(self):
     dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
-    dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset.snapshot(path=self._snapshot_dir)
 
   @combinations.generate(test_base.default_test_combinations())
   def testReadSnapshotDatasetDefault(self):
@@ -132,7 +132,7 @@
     ]
 
     dataset = core_readers._TFRecordDataset(filenames)
-    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset, expected)
     self.assertSnapshotDirectoryContains(
         self._snapshot_dir,
@@ -142,7 +142,7 @@
 
     self.removeTFRecords()
     dataset2 = core_readers._TFRecordDataset(filenames)
-    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset2 = dataset2.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset2, expected)
 
   @combinations.generate(test_base.default_test_combinations())
@@ -156,14 +156,15 @@
     ]
 
     dataset = core_readers._TFRecordDataset(filenames)
-    dataset = dataset.apply(
-        snapshot.snapshot(self._snapshot_dir, compression="AUTO"))
+    dataset = dataset.snapshot(path=self._snapshot_dir, compression="AUTO")
     self.assertDatasetProduces(dataset, expected)
 
     self.removeTFRecords()
     dataset2 = core_readers._TFRecordDataset(filenames)
-    dataset2 = dataset2.apply(
-        snapshot.snapshot(self._snapshot_dir, compression="SNAPPY"))
+    dataset2 = dataset2.snapshot(
+        path=self._snapshot_dir,
+        compression="SNAPPY"
+    )
     self.assertDatasetProduces(dataset2, expected)
 
   @combinations.generate(test_base.default_test_combinations())
@@ -177,8 +178,10 @@
     ]
 
     dataset = core_readers._TFRecordDataset(filenames)
-    dataset = dataset.apply(
-        snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: np.int64(0)))
+    dataset = dataset.snapshot(
+        path=self._snapshot_dir,
+        shard_func=lambda _: np.int64(0)
+    )
     self.assertDatasetProduces(dataset, expected)
     self.assertSnapshotDirectoryContains(
         self._snapshot_dir,
@@ -188,8 +191,10 @@
 
     self.removeTFRecords()
     dataset2 = core_readers._TFRecordDataset(filenames)
-    dataset2 = dataset2.apply(
-        snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: 0))
+    dataset2 = dataset2.snapshot(
+        path=self._snapshot_dir,
+        shard_func=lambda _: 0
+    )
     self.assertDatasetProduces(dataset2, expected)
 
   @combinations.generate(test_base.default_test_combinations())
@@ -203,14 +208,13 @@
     ]
 
     dataset = core_readers._TFRecordDataset(filenames)
-    dataset = dataset.apply(
-        snapshot.snapshot(
-            self._snapshot_dir,
+    dataset = dataset.snapshot(
+            path=self._snapshot_dir,
             reader_func=(
                 lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                     lambda x: x,
                     cycle_length=4,
-                    num_parallel_calls=4))))
+                    num_parallel_calls=4)))
     self.assertDatasetProduces(dataset, expected)
     self.assertSnapshotDirectoryContains(
         self._snapshot_dir,
@@ -220,23 +224,22 @@
 
     self.removeTFRecords()
     dataset2 = core_readers._TFRecordDataset(filenames)
-    dataset2 = dataset2.apply(
-        snapshot.snapshot(
+    dataset2 = dataset2.snapshot(
             self._snapshot_dir,
             reader_func=(
                 lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                     lambda x: x,
                     cycle_length=4,
-                    num_parallel_calls=4))))
+                    num_parallel_calls=4)))
     self.assertDatasetProducesSet(dataset2, expected)
 
   @combinations.generate(test_base.default_test_combinations())
   def testSnapshotDatasetInvalidShardFn(self):
     dataset = dataset_ops.Dataset.range(1000)
     with self.assertRaises(TypeError):
-      dataset = dataset.apply(
-          snapshot.snapshot(
-              self._snapshot_dir, shard_func=lambda _: "invalid_fn"))
+      dataset = dataset.snapshot(
+              path=self._snapshot_dir,
+              shard_func=lambda _: "invalid_fn")
       next_fn = self.getNext(dataset)
       self.evaluate(next_fn())
 
@@ -244,15 +247,17 @@
   def testSnapshotDatasetInvalidReaderFn(self):
     dataset = dataset_ops.Dataset.range(1000)
     with self.assertRaises(TypeError):
-      dataset = dataset.apply(
-          snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1))
+      dataset = dataset.snapshot(
+          path=self._snapshot_dir,
+          reader_func=lambda x: x + 1
+      )
       next_fn = self.getNext(dataset)
       self.evaluate(next_fn())
 
   @combinations.generate(test_base.default_test_combinations())
   def testRoundtripEmptySnapshot(self):
     dataset = dataset_ops.Dataset.range(0)
-    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset, [])
     self.assertSnapshotDirectoryContains(
         self._snapshot_dir,
@@ -261,13 +266,13 @@
         num_snapshot_shards_per_run=0)
 
     dataset2 = dataset_ops.Dataset.range(0)
-    dataset2 = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset2 = dataset.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset2, [])
 
   @combinations.generate(test_base.default_test_combinations())
   def testWriteSnapshotDatasetSimple(self):
     dataset = dataset_ops.Dataset.range(1000)
-    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset, list(range(1000)))
     self.assertSnapshotDirectoryContains(
         self._snapshot_dir,
@@ -278,11 +283,11 @@
   @combinations.generate(test_base.default_test_combinations())
   def testWriteSnapshotDatasetMultipleFingerprints(self):
     dataset1 = dataset_ops.Dataset.range(1000)
-    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset1 = dataset1.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset1, list(range(1000)))
 
     dataset2 = dataset_ops.Dataset.range(2000)
-    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset2 = dataset2.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset2, list(range(2000)))
 
     self.assertSnapshotDirectoryContains(
@@ -294,10 +299,10 @@
   @combinations.generate(test_base.default_test_combinations())
   def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
     dataset1 = dataset_ops.Dataset.range(1000)
-    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset1 = dataset1.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset1, list(range(1000)))
     dataset2 = dataset_ops.Dataset.range(1000)
-    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset2 = dataset2.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset2, list(range(1000)))
 
     self.assertSnapshotDirectoryContains(
@@ -309,13 +314,13 @@
   @combinations.generate(test_base.default_test_combinations())
   def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self):
     dataset1 = dataset_ops.Dataset.range(1000)
-    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset1 = dataset1.snapshot(path=self._snapshot_dir)
     next1 = self.getNext(dataset1)
     for i in range(500):
       self.assertEqual(i, self.evaluate(next1()))
 
     dataset2 = dataset_ops.Dataset.range(1000)
-    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset2 = dataset2.snapshot(path=self._snapshot_dir)
     next2 = self.getNext(dataset2)
     for i in range(500):
       self.assertEqual(i, self.evaluate(next2()))
@@ -334,8 +339,10 @@
   def testWriteSnapshotCustomShardFunction(self):
     dataset = dataset_ops.Dataset.range(1000)
     dataset = dataset.enumerate()
-    dataset = dataset.apply(
-        snapshot.snapshot(self._snapshot_dir, shard_func=lambda i, _: i % 2))
+    dataset = dataset.snapshot(
+        path=self._snapshot_dir,
+        shard_func=lambda i, _: i % 2
+    )
     dataset = dataset.map(lambda _, elem: elem)
     self.assertDatasetProduces(dataset, list(range(1000)))
     self.assertSnapshotDirectoryContains(
@@ -352,7 +359,7 @@
     dataset4 = dataset_ops.Dataset.range(3000, 4000)
 
     dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4))
-    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.snapshot(path=self._snapshot_dir)
 
     expected = list(
         zip(
@@ -371,7 +378,7 @@
     def make_dataset():
       dataset = dataset_ops.Dataset.range(1000)
       dataset = dataset.shuffle(1000)
-      dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+      dataset = dataset.snapshot(path=self._snapshot_dir)
       return dataset
 
     dataset1 = make_dataset()
@@ -387,7 +394,7 @@
   @combinations.generate(test_base.default_test_combinations())
   def testReadUsingFlatMap(self):
     dataset = dataset_ops.Dataset.range(1000)
-    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.snapshot(path=self._snapshot_dir)
     self.assertDatasetProduces(dataset, list(range(1000)))
     flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x)
     self.assertDatasetProduces(flat_map, list(range(1000)))
@@ -403,7 +410,7 @@
     # Will be optimized into ShuffleAndRepeat.
     dataset = dataset.shuffle(10)
     dataset = dataset.repeat(2)
-    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.snapshot(path=self._snapshot_dir)
     self.assertDatasetProducesSet(dataset, 2 * list(range(1000)))
     flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x)
     self.assertDatasetProducesSet(flat_map, 2 * list(range(1000)))
@@ -417,7 +424,7 @@
   def testRepeatAndPrefetch(self):
     """This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
     dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
-    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.snapshot(path=self._snapshot_dir)
     dataset = dataset.shuffle(buffer_size=16)
     dataset = dataset.batch(16)
     dataset = dataset.repeat()
@@ -1031,7 +1038,7 @@
         os.mkdir(self._snapshot_dir)
 
       dataset = dataset_ops.Dataset.range(100)
-      dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+      dataset = dataset.snapshot(path=self._snapshot_dir)
       if repeat:
         dataset = dataset.repeat(2)
       return dataset