| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for the `SnapshotDataset` transformation.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import multiprocessing |
| import os |
| import shutil |
| import time |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.python.data.experimental.ops import snapshot |
| from tensorflow.python.data.kernel_tests import checkpoint_test_base |
| from tensorflow.python.data.kernel_tests import test_base |
| from tensorflow.python.data.kernel_tests import tf_record_test_base |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.data.ops import readers as core_readers |
| from tensorflow.python.framework import combinations |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import gen_array_ops |
| from tensorflow.python.ops import string_ops |
| from tensorflow.python.platform import test |
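| |
| # A minimal sketch of the API exercised below (illustrative only; the path |
| # is a placeholder and the code is left commented out so importing this |
| # module has no side effects): |
| # |
| #   dataset = dataset_ops.Dataset.range(10) |
| #   # The first complete pass writes the snapshot under `path`; later runs |
| #   # of the same pipeline (same fingerprint) read it back instead of |
| #   # recomputing the input. |
| #   dataset = dataset.snapshot(path="/tmp/snapshot_dir") |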
| |
| |
| def is_graphdef_file(filename): |
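| """Returns True if `filename` is a graphdef pbtxt written for debugging.""" |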
| return filename.endswith("-graph.pbtxt") |
| |
| |
| def is_temp_file(filename): |
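| """Returns True if `filename` is a temporary file left over from a write.""" |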
| return "-tmp-" in filename |
| |
| |
| def listdir_and_filter(dirname, filter_fn): |
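| """Returns the sorted entries of `dirname` that pass `filter_fn`.""" |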
| return [path for path in sorted(os.listdir(dirname)) if filter_fn(path)] |
| |
| |
| class SnapshotTest(tf_record_test_base.TFRecordTestBase, |
| parameterized.TestCase): |
| |
| def setUp(self): |
| super(SnapshotTest, self).setUp() |
| tmpdir = self.get_temp_dir() |
| tmpdir = os.path.join(tmpdir, "snapshot") |
| os.mkdir(tmpdir) |
| self._snapshot_dir = tmpdir |
| |
| def tearDown(self): |
| super(SnapshotTest, self).tearDown() |
| shutil.rmtree(self._snapshot_dir) |
| |
| def createTFRecords(self, num_files=10, num_records=100): |
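| """Writes `num_files` TFRecord files with `num_records` records each.""" |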
| self._num_files = num_files |
| self._num_records = num_records |
| self._filenames = self._createFiles() |
| |
| def removeTFRecords(self): |
| for filename in self._filenames: |
| os.remove(filename) |
| self._filenames = [] |
| self._num_files = None |
| self._num_records = None |
| |
| def assertDatasetProducesSet(self, dataset, expected): |
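| """Asserts that `dataset` produces the `expected` elements in any order.""" |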
| actual = [] |
| next_fn = self.getNext(dataset) |
| for _ in range(len(expected)): |
| elem = self.evaluate(next_fn()) |
| actual.append(elem) |
| self.assertCountEqual(actual, expected) |
| with self.assertRaises(errors.OutOfRangeError): |
| self.evaluate(next_fn()) |
| |
| def assertSnapshotDirectoryContains(self, directory, num_fingerprints, |
| num_runs_per_fingerprint, |
| num_snapshot_shards_per_run): |
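| """Checks the directory layout written by `Dataset.snapshot`. |
| |
| Expected layout: `num_fingerprints` fingerprint directories under |
| `directory`; each contains `snapshot.metadata` plus |
| `num_runs_per_fingerprint` run directories; each run directory contains |
| `num_snapshot_shards_per_run` files named `%08d.shard`. |
| """ |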
| |
| # Ignore the graphdef pbtxts we write for debugging purposes and temporary |
| # files that are an artifact of how TF writes files. |
| dirlist = listdir_and_filter( |
| directory, |
| lambda p: not (is_graphdef_file(p) or is_temp_file(p))) |
| self.assertLen(dirlist, num_fingerprints) |
| |
| for i in range(num_fingerprints): |
| fingerprint_dir = os.path.join(directory, dirlist[i]) |
| fingerprint_dir_list = listdir_and_filter( |
| fingerprint_dir, lambda p: not is_temp_file(p)) |
| self.assertLen(fingerprint_dir_list, num_runs_per_fingerprint + 1) |
| self.assertEqual(fingerprint_dir_list[num_runs_per_fingerprint], |
| "snapshot.metadata") |
| |
| for j in range(num_runs_per_fingerprint): |
| run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j]) |
| run_dirlist = sorted(os.listdir(run_dir)) |
| self.assertLen(run_dirlist, num_snapshot_shards_per_run) |
| |
| file_counter = 0 |
| for filename in run_dirlist: |
| self.assertEqual(filename, "%08d.shard" % file_counter) |
| file_counter += 1 |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testCreateSnapshotDataset(self): |
| dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]) |
| dataset.snapshot(path=self._snapshot_dir) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadSnapshotDatasetDefault(self): |
| self.createTFRecords() |
| filenames = self._filenames |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 100) |
| ] |
| |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset, expected) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| self.removeTFRecords() |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset2, expected) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadSnapshotDatasetAutoWriteSnappyRead(self): |
| self.createTFRecords() |
| filenames = self._filenames |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 100) |
| ] |
| |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.snapshot(path=self._snapshot_dir, compression="AUTO") |
| self.assertDatasetProduces(dataset, expected) |
| |
| self.removeTFRecords() |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.snapshot( |
| path=self._snapshot_dir, |
| compression="SNAPPY" |
| ) |
| self.assertDatasetProduces(dataset2, expected) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadSnapshotDatasetCustomShardFn(self): |
| self.createTFRecords() |
| filenames = self._filenames |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 100) |
| ] |
| |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.snapshot( |
| path=self._snapshot_dir, |
| shard_func=lambda _: np.int64(0) |
| ) |
| self.assertDatasetProduces(dataset, expected) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=1) |
| |
| self.removeTFRecords() |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.snapshot( |
| path=self._snapshot_dir, |
| shard_func=lambda _: 0 |
| ) |
| self.assertDatasetProduces(dataset2, expected) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadSnapshotDatasetCustomReaderFn(self): |
| self.createTFRecords() |
| filenames = self._filenames |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 100) |
| ] |
| |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.snapshot( |
| path=self._snapshot_dir, |
| reader_func=( |
| lambda ds: ds.interleave( # pylint:disable=g-long-lambda |
| lambda x: x, |
| cycle_length=4, |
| num_parallel_calls=4))) |
| self.assertDatasetProduces(dataset, expected) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| self.removeTFRecords() |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.snapshot( |
| self._snapshot_dir, |
| reader_func=( |
| lambda ds: ds.interleave( # pylint:disable=g-long-lambda |
| lambda x: x, |
| cycle_length=4, |
| num_parallel_calls=4))) |
| self.assertDatasetProducesSet(dataset2, expected) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testSnapshotDatasetInvalidShardFn(self): |
| dataset = dataset_ops.Dataset.range(1000) |
| with self.assertRaises(TypeError): |
| dataset = dataset.snapshot( |
| path=self._snapshot_dir, |
| shard_func=lambda _: "invalid_fn") |
| next_fn = self.getNext(dataset) |
| self.evaluate(next_fn()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testSnapshotDatasetInvalidReaderFn(self): |
| dataset = dataset_ops.Dataset.range(1000) |
| with self.assertRaises(TypeError): |
| dataset = dataset.snapshot( |
| path=self._snapshot_dir, |
| reader_func=lambda x: x + 1 |
| ) |
| next_fn = self.getNext(dataset) |
| self.evaluate(next_fn()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testRoundtripEmptySnapshot(self): |
| dataset = dataset_ops.Dataset.range(0) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset, []) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=0) |
| |
| dataset2 = dataset_ops.Dataset.range(0) |
| dataset2 = dataset2.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset2, []) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotDatasetSimple(self): |
| dataset = dataset_ops.Dataset.range(1000) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset, list(range(1000))) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotDatasetMultipleFingerprints(self): |
| dataset1 = dataset_ops.Dataset.range(1000) |
| dataset1 = dataset1.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset1, list(range(1000))) |
| |
| dataset2 = dataset_ops.Dataset.range(2000) |
| dataset2 = dataset2.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset2, list(range(2000))) |
| |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=2, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self): |
| dataset1 = dataset_ops.Dataset.range(1000) |
| dataset1 = dataset1.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset1, list(range(1000))) |
| dataset2 = dataset_ops.Dataset.range(1000) |
| dataset2 = dataset2.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset2, list(range(1000))) |
| |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self): |
| dataset1 = dataset_ops.Dataset.range(1000) |
| dataset1 = dataset1.snapshot(path=self._snapshot_dir) |
| next1 = self.getNext(dataset1) |
| for i in range(500): |
| self.assertEqual(i, self.evaluate(next1())) |
| |
| dataset2 = dataset_ops.Dataset.range(1000) |
| dataset2 = dataset2.snapshot(path=self._snapshot_dir) |
| next2 = self.getNext(dataset2) |
| for i in range(500): |
| self.assertEqual(i, self.evaluate(next2())) |
| |
| for i in range(500, 1000): |
| self.assertEqual(i, self.evaluate(next1())) |
| self.assertEqual(i, self.evaluate(next2())) |
| |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=2, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotCustomShardFunction(self): |
| dataset = dataset_ops.Dataset.range(1000) |
| dataset = dataset.enumerate() |
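| # Shard on the enumeration index so elements alternate between two shards. |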
| dataset = dataset.snapshot( |
| path=self._snapshot_dir, |
| shard_func=lambda i, _: i % 2 |
| ) |
| dataset = dataset.map(lambda _, elem: elem) |
| self.assertDatasetProduces(dataset, list(range(1000))) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=2) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotDatasetWithTuples(self): |
| dataset1 = dataset_ops.Dataset.range(0, 1000) |
| dataset2 = dataset_ops.Dataset.range(1000, 2000) |
| dataset3 = dataset_ops.Dataset.range(2000, 3000) |
| dataset4 = dataset_ops.Dataset.range(3000, 4000) |
| |
| dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4)) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| |
| expected = list( |
| zip( |
| range(0, 1000), range(1000, 2000), range(2000, 3000), |
| range(3000, 4000))) |
| self.assertDatasetProduces(dataset, expected) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotShuffleSameFingerprint(self): |
| |
| def make_dataset(): |
| dataset = dataset_ops.Dataset.range(1000) |
| dataset = dataset.shuffle(1000) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| return dataset |
| |
| dataset1 = make_dataset() |
| self.assertDatasetProducesSet(dataset1, list(range(1000))) |
| dataset2 = make_dataset() |
| self.assertDatasetProducesSet(dataset2, list(range(1000))) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadUsingFlatMap(self): |
| dataset = dataset_ops.Dataset.range(1000) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProduces(dataset, list(range(1000))) |
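| # Read the snapshot back through a nested dataset via flat_map. |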
| flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x) |
| self.assertDatasetProduces(flat_map, list(range(1000))) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadOptimizableUsingFlatMap(self): |
| dataset = dataset_ops.Dataset.range(1000) |
| # Will be optimized into ShuffleAndRepeat. |
| dataset = dataset.shuffle(10) |
| dataset = dataset.repeat(2) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| self.assertDatasetProducesSet(dataset, 2 * list(range(1000))) |
| flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x) |
| self.assertDatasetProducesSet(flat_map, 2 * list(range(1000))) |
| self.assertSnapshotDirectoryContains( |
| self._snapshot_dir, |
| num_fingerprints=1, |
| num_runs_per_fingerprint=1, |
| num_snapshot_shards_per_run=multiprocessing.cpu_count()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testRepeatAndPrefetch(self): |
| """This test reproduces github.com/tensorflow/tensorflow/issues/48903.""" |
| dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32)) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| dataset = dataset.shuffle(buffer_size=16) |
| dataset = dataset.batch(16) |
| dataset = dataset.repeat() |
| dataset = dataset.prefetch(1) |
| next_element = self.getNext(dataset) |
| for _ in range(30): |
| self.evaluate(next_element()) |
| |
| |
| class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase, |
| parameterized.TestCase): |
| |
| def setUp(self): |
| super(LegacySnapshotTest, self).setUp() |
| self.removeTFRecords() |
| tmpdir = self.get_temp_dir() |
| tmpdir = os.path.join(tmpdir, "snapshot") |
| os.mkdir(tmpdir) |
| self.snapshot_dir = tmpdir |
| |
| def tearDown(self): |
| super(LegacySnapshotTest, self).tearDown() |
| shutil.rmtree(self.snapshot_dir) |
| |
| def removeTFRecords(self): |
| for filename in self._filenames: |
| os.remove(filename) |
| self._filenames = [] |
| |
| def setUpTFRecord(self, num_files=10, num_records=10): |
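| """Creates `num_files` TFRecord files with `num_records` records each.""" |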
| self._num_files = num_files |
| self._num_records = num_records |
| self._filenames = self._createFiles() |
| |
| def makeSnapshotDirectory(self): |
| return self.snapshot_dir |
| |
| def assertSnapshotDirectoryContains(self, directory, num_fingerprints, |
| num_runs_per_fp, num_snapshot_files): |
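| """Checks the directory layout written by `legacy_snapshot`. |
| |
| The layout matches `SnapshotTest.assertSnapshotDirectoryContains`, except |
| that shard files are named `%08d.snapshot` instead of `%08d.shard`. |
| """ |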
| # Ignore the graphdef pbtxts we write for debugging purposes and temporary |
| # files that are an artifact of how TF writes files. |
| dirlist = listdir_and_filter( |
| directory, |
| lambda p: not (is_graphdef_file(p) or is_temp_file(p))) |
| self.assertLen(dirlist, num_fingerprints) |
| |
| for i in range(num_fingerprints): |
| fingerprint_dir = os.path.join(directory, dirlist[i]) |
| fingerprint_dir_list = listdir_and_filter( |
| fingerprint_dir, lambda p: not is_temp_file(p)) |
| self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1) |
| self.assertEqual(fingerprint_dir_list[num_runs_per_fp], |
| "snapshot.metadata") |
| |
| for j in range(num_runs_per_fp): |
| run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j]) |
| run_dirlist = sorted(os.listdir(run_dir)) |
| self.assertLen(run_dirlist, num_snapshot_files) |
| |
| file_counter = 0 |
| for filename in run_dirlist: |
| self.assertEqual(filename, "%08d.snapshot" % file_counter) |
| file_counter += 1 |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteDifferentPipelinesInOneDirectory(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.range(1000) |
| dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) |
| self.assertDatasetProduces(dataset, list(range(1000))) |
| |
| dataset = dataset_ops.Dataset.range(1001) |
| dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) |
| self.assertDatasetProduces(dataset, list(range(1001))) |
| |
| self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testWriteSnapshotMultipleSimultaneous(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset1 = dataset_ops.Dataset.range(1000) |
| dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir)) |
| next1 = self.getNext(dataset1) |
| |
| dataset2 = dataset_ops.Dataset.range(1000) |
| dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir)) |
| next2 = self.getNext(dataset2) |
| |
| for i in range(0, 1000): |
| self.assertEqual(i, self.evaluate(next1())) |
| self.assertEqual(i, self.evaluate(next2())) |
| |
| # We check that only one copy of the metadata has been written; the |
| # pipeline that lost the race operates in passthrough mode. |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testGetNextCreatesDir(self): |
| tmpdir = self.snapshot_dir |
| |
| # We create two iterators but call getNext on only one. |
| dataset1 = dataset_ops.Dataset.range(1000) |
| dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir)) |
| next1 = self.getNext(dataset1) |
| |
| dataset2 = dataset_ops.Dataset.range(1001) |
| dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir)) |
| _ = self.getNext(dataset2) |
| |
| for _ in range(1000): |
| self.evaluate(next1()) |
| |
| # We check that only one directory is created. |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(compression=[ |
| snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, |
| snapshot.COMPRESSION_SNAPPY |
| ]))) |
| def testWriteSnapshotSimpleSuccessful(self, compression): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.range(1000) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, compression=compression)) |
| self.assertDatasetProduces(dataset, list(range(1000))) |
| |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(compression=[ |
| snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, |
| snapshot.COMPRESSION_SNAPPY |
| ]))) |
| def testWriteSnapshotRepeatAfterwards(self, compression): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.range(10) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, compression=compression)) |
| dataset = dataset.repeat(10) |
| self.assertDatasetProduces(dataset, list(range(10)) * 10) |
| |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(compression=[ |
| snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, |
| snapshot.COMPRESSION_SNAPPY |
| ]))) |
| def testWriteSnapshotMixTypes(self, compression): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.range(10) |
| |
| def map_fn(x): |
| return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x) |
| |
| dataset = dataset.map(map_fn) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, compression=compression)) |
| dataset = dataset.repeat(10) |
| |
| expected = [] |
| for i in range(10): |
| expected.append((i, str(i), str(2 * i), 2 * i)) |
| self.assertDatasetProduces(dataset, expected * 10) |
| |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testSpecifySnapshotNameWriteAndRead(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.range(10) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, snapshot_name="my_custom_snapshot")) |
| dataset = dataset.repeat(10) |
| self.assertDatasetProduces(dataset, list(range(10)) * 10) |
| |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| self.assertTrue( |
| os.path.exists(os.path.join(tmpdir, "custom-my_custom_snapshot"))) |
| self.assertTrue( |
| os.path.exists( |
| os.path.join(tmpdir, "custom-my_custom_snapshot", "custom"))) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testForcePassthroughMode(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.range(10) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, mode="passthrough")) |
| dataset = dataset.repeat(10) |
| self.assertDatasetProduces(dataset, list(range(10)) * 10) |
| |
| self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testForceWriteMode(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.range(10) |
| dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="write")) |
| dataset = dataset.repeat(10) |
| self.assertDatasetProduces(dataset, list(range(10)) * 10) |
| |
| # We will end up writing 10 different runs. |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testForceReadMode(self): |
| tmpdir = self.snapshot_dir |
| |
| # We write a copy of the snapshot first. |
| dataset = dataset_ops.Dataset.range(10) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, mode="write", snapshot_name="my_custom_snapshot")) |
| self.assertDatasetProduces(dataset, list(range(10))) |
| |
| # We move the run to a new name. |
| shutil.move( |
| os.path.join(tmpdir, "custom-my_custom_snapshot"), |
| os.path.join(tmpdir, "custom-my_custom_snapshot_2")) |
| |
| # Even though snapshot.metadata still points to the old run, which no |
| # longer exists after the move, we force the pipeline to read from the run |
| # we specify. |
| dataset = dataset_ops.Dataset.range(10) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, mode="read", snapshot_name="my_custom_snapshot_2")) |
| self.assertDatasetProduces(dataset, list(range(10))) |
| |
| # We should still have one snapshot and one run. |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testForceReadNonexistentSnapshot(self): |
| tmpdir = self.snapshot_dir |
| dataset = dataset_ops.Dataset.range(10) |
| with self.assertRaises(errors.NotFoundError): |
| dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="read")) |
| get_next = self.getNext(dataset) |
| self.evaluate(get_next()) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testForceReadNonexistentNamedSnapshot(self): |
| tmpdir = self.snapshot_dir |
| dataset = dataset_ops.Dataset.range(10) |
| with self.assertRaises(errors.NotFoundError): |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, mode="read", snapshot_name="my_nonexistent_snapshot")) |
| get_next = self.getNext(dataset) |
| self.evaluate(get_next()) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(compression=[ |
| snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, |
| snapshot.COMPRESSION_SNAPPY |
| ]))) |
| def testReadSnapshotBackAfterWrite(self, compression): |
| self.setUpTFRecord() |
| filenames = self._filenames |
| |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 10) |
| ] |
| |
| tmpdir = self.snapshot_dir |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, compression=compression)) |
| self.assertDatasetProduces(dataset, expected) |
| |
| # Remove the original files and try to read the data back only from the |
| # snapshot. |
| self.removeTFRecords() |
| |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.apply( |
| snapshot.legacy_snapshot(tmpdir, compression=compression)) |
| self.assertDatasetProduces(dataset2, expected) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadShuffledSnapshotAfterWrite(self): |
| self.setUpTFRecord(num_files=10, num_records=50) |
| filenames = self._filenames |
| |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 50) |
| ] |
| |
| tmpdir = self.snapshot_dir |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100)) |
| self.assertDatasetProduces(dataset, expected) |
| |
| # Remove the original files and try to read the data back only from the |
| # snapshot. |
| self.removeTFRecords() |
| |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, shard_size_bytes=100, shuffle_on_read=True)) |
| next2 = self.getNext(dataset2) |
| |
| res1 = self.evaluate(next2()) |
| res2 = self.evaluate(next2()) |
| res3 = self.evaluate(next2()) |
| res4 = self.evaluate(next2()) |
| res5 = self.evaluate(next2()) |
| |
| # Make sure that we don't read the file back in the same order. |
| self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5]) |
| |
| # Make sure all the elements are still there. |
| dataset3 = core_readers._TFRecordDataset(filenames) |
| dataset3 = dataset3.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, shard_size_bytes=100, shuffle_on_read=True)) |
| self.assertDatasetProduces(dataset3, expected, assert_items_equal=True) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testReadShuffledSnapshotWithSeedAfterWrite(self): |
| self.setUpTFRecord(num_files=10, num_records=50) |
| filenames = self._filenames |
| |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 50) |
| ] |
| |
| tmpdir = self.snapshot_dir |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10)) |
| self.assertDatasetProduces(dataset, expected) |
| |
| # Remove the original files and try to read the data back only from the |
| # snapshot. |
| self.removeTFRecords() |
| |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, |
| shard_size_bytes=10, |
| shuffle_on_read=True, |
| shuffle_seed=123456)) |
| next2 = self.getNext(dataset2) |
| |
| dataset3 = core_readers._TFRecordDataset(filenames) |
| dataset3 = dataset3.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, |
| shard_size_bytes=10, |
| shuffle_on_read=True, |
| shuffle_seed=123456)) |
| next3 = self.getNext(dataset3) |
| |
| # Make sure that the items are read back in the same order for both |
| # datasets. |
| for _ in range(500): |
| res2 = self.evaluate(next2()) |
| res3 = self.evaluate(next3()) |
| self.assertEqual(res2, res3) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(compression=[ |
| snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, |
| snapshot.COMPRESSION_SNAPPY |
| ]))) |
| def testReadSnapshotParallelAfterWrite(self, compression): |
| self.setUpTFRecord(5, 500) |
| filenames = self._filenames |
| |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 5) |
| for r in range(0, 500) |
| ] |
| |
| tmpdir = self.snapshot_dir |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, |
| shard_size_bytes=1024 * 1024, |
| num_reader_threads=2, |
| reader_buffer_size=10, |
| compression=compression)) |
| self.assertDatasetProduces(dataset, expected, assert_items_equal=True) |
| |
| # Remove the original files and try to read the data back only from the |
| # snapshot. |
| self.removeTFRecords() |
| |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, |
| shard_size_bytes=1024 * 1024, |
| num_reader_threads=2, |
| reader_buffer_size=10, |
| compression=compression)) |
| self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) |
| |
| # Not testing Snappy here because Snappy reads currently require a lot of |
| # memory. |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.times( |
| combinations.combine(compression=[ |
| snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP |
| ]), |
| combinations.combine(threads=2, size=[1, 2]) + |
| combinations.combine(threads=8, size=[1, 4, 8])))) |
| def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads, |
| size): |
| self.setUpTFRecord() |
| filenames = self._filenames |
| |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 10) |
| ] |
| |
| tmpdir = self.snapshot_dir |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, |
| compression=compression, |
| num_writer_threads=threads, |
| writer_buffer_size=size)) |
| self.assertDatasetProduces(dataset, expected) |
| |
| # Remove the original files and try to read the data back only from the |
| # snapshot. |
| self.removeTFRecords() |
| |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.apply( |
| snapshot.legacy_snapshot(tmpdir, compression=compression)) |
| self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testSameFingerprintWithDifferentInitializationOrder(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset1 = dataset_ops.Dataset.range(0, 100) |
| dataset2 = dataset_ops.Dataset.range(100, 200) |
| dataset3 = dataset_ops.Dataset.range(200, 300) |
| |
| dataset = dataset1.concatenate(dataset2).concatenate(dataset3) |
| dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) |
| self.assertDatasetProduces(dataset, list(range(300))) |
| |
| dataset4 = dataset_ops.Dataset.range(200, 300) |
| dataset5 = dataset_ops.Dataset.range(100, 200) |
| dataset6 = dataset_ops.Dataset.range(0, 100) |
| |
| dataset = dataset6.concatenate(dataset5).concatenate(dataset4) |
| dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) |
| self.assertDatasetProduces(dataset, list(range(300))) |
| |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testExpiredSnapshotRewrite(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset1 = dataset_ops.Dataset.range(1000) |
| dataset1 = dataset1.apply( |
| snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) |
| next1 = self.getNext(dataset1) |
| |
| # Don't finish reading dataset1, so it is never finalized. |
| for _ in range(500): |
| self.evaluate(next1()) |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| time.sleep(2) |
| |
| # We create dataset2 only after running through dataset1 because, in eager |
| # mode, the snapshot state is determined immediately when the dataset is |
| # created. We want the snapshot state for dataset2 to be determined only |
| # after the first snapshot has expired. |
| dataset2 = dataset_ops.Dataset.range(1000) |
| dataset2 = dataset2.apply( |
| snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) |
| next2 = self.getNext(dataset2) |
| |
| for _ in range(500): |
| self.evaluate(next2()) |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testSnapshotArgsCreateNewSnapshot(self): |
| tmpdir = self.snapshot_dir |
| |
| dataset1 = dataset_ops.Dataset.range(1000) |
| dataset1 = dataset1.apply( |
| snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10000)) |
| next1 = self.getNext(dataset1) |
| |
| for _ in range(1000): |
| self.evaluate(next1()) |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) |
| |
| # Create a second snapshot with a different shard_size_bytes. |
| dataset2 = dataset_ops.Dataset.range(1000) |
| dataset2 = dataset2.apply( |
| snapshot.legacy_snapshot(tmpdir, shard_size_bytes=20000)) |
| next2 = self.getNext(dataset2) |
| |
| for _ in range(1000): |
| self.evaluate(next2()) |
| self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(compression=[ |
| snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, |
| snapshot.COMPRESSION_SNAPPY |
| ]))) |
| def testSpecifyShardSize(self, compression): |
| tmpdir = self.snapshot_dir |
| |
| dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) |
| dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) |
| dataset = dataset.repeat(10) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot( |
| tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression)) |
| next_fn = self.getNext(dataset) |
| |
| for _ in range(10): |
| self.evaluate(next_fn()) |
| |
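| # Without compression, each record takes more bytes on disk, so the shard |
| # size limit is reached sooner and more shard files are written. |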
| num_files = 1 |
| if compression == snapshot.COMPRESSION_NONE: |
| num_files = 3 |
| self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testAdditionalOperationsAfterReadBack(self): |
| self.setUpTFRecord() |
| filenames = self._filenames |
| |
| expected = [ |
| b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 10) |
| ] |
| |
| tmpdir = self.snapshot_dir |
| dataset = core_readers._TFRecordDataset(filenames) |
| dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) |
| self.assertDatasetProduces(dataset, expected) |
| |
| # Remove the original files and try to read the data back only from the |
| # snapshot. |
| self.removeTFRecords() |
| |
| dataset2 = core_readers._TFRecordDataset(filenames) |
| dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir)) |
| self.assertDatasetProduces(dataset2, expected) |
| |
| expected_after = [ |
| b"cord %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension |
| for f in range(0, 10) |
| for r in range(0, 10) |
| ] |
| |
| dataset3 = core_readers._TFRecordDataset(filenames) |
| dataset3 = dataset3.apply(snapshot.legacy_snapshot(tmpdir)) |
| dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000)) |
| self.assertDatasetProduces(dataset3, expected_after) |
| |
| |
| class SnapshotCheckpointTest(checkpoint_test_base.CheckpointTestBase, |
| parameterized.TestCase): |
| |
| def _build_snapshot_dataset(self, repeat=False): |
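| """Returns a function that builds a snapshotted `range(100)` dataset. |
| |
| If `repeat` is True, the dataset is repeated twice after the snapshot. |
| """ |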
| |
| def ds_fn(): |
| self._snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot") |
| if not os.path.exists(self._snapshot_dir): |
| os.mkdir(self._snapshot_dir) |
| |
| dataset = dataset_ops.Dataset.range(100) |
| dataset = dataset.snapshot(path=self._snapshot_dir) |
| if repeat: |
| dataset = dataset.repeat(2) |
| return dataset |
| |
| return ds_fn |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testCheckpointBeforeEpochEndNoRepeat(self): |
| ds_fn = self._build_snapshot_dataset(repeat=False) |
| outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False) |
| self.assertSequenceEqual(outputs, range(50)) |
| outputs.extend( |
| self.gen_outputs(ds_fn, [], 50, ckpt_saved=True, verify_exhausted=True)) |
| self.assertSequenceEqual(outputs, range(100)) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testCheckpointBeforeOneEpochWithReading(self): |
| ds_fn = self._build_snapshot_dataset(repeat=True) |
| |
| # Generate 50 entries from iterator and save checkpoint. |
| outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False) |
| self.assertSequenceEqual(outputs, list(range(50))) |
| |
| # Restore from checkpoint and produce the rest of the elements from the |
| # iterator. |
| t = self.gen_outputs(ds_fn, [], 150, ckpt_saved=True, verify_exhausted=True) |
| outputs.extend(t) |
| self.assertSequenceEqual( |
| outputs, |
| list(range(50)) + list(range(50, 100)) + list(range(100))) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testCheckpointBeforeOneEpochThenRunAFewSteps(self): |
| ds_fn = self._build_snapshot_dataset(repeat=False) |
| outputs = self.gen_outputs( |
| ds_fn, [10], 20, verify_exhausted=False, save_checkpoint_at_end=False) |
| self.assertSequenceEqual(outputs, range(20)) |
| |
| outputs = outputs[:10] |
| outputs.extend( |
| self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)) |
| self.assertSequenceEqual(outputs, range(100)) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testCheckpointAfterOneEpoch(self): |
| ds_fn = self._build_snapshot_dataset(repeat=True) |
| |
| # Generate 110 entries from iterator and save checkpoint. |
| outputs = self.gen_outputs(ds_fn, [], 110, verify_exhausted=False) |
| self.assertSequenceEqual(outputs, list(range(100)) + list(range(10))) |
| |
| # Restore from checkpoint and produce the rest of the elements from the |
| # iterator. |
| t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True) |
| outputs.extend(t) |
| self.assertSequenceEqual( |
| outputs, |
| list(range(100)) + list(range(10)) + list(range(10, 100))) |
| |
| @combinations.generate(test_base.default_test_combinations()) |
| def testCheckpointAfterOneEpochRunFewSteps(self): |
| ds_fn = self._build_snapshot_dataset(repeat=True) |
| |
| # Generate 120 entries from iterator and save checkpoint at 110. |
| outputs = self.gen_outputs( |
| ds_fn, [110], 120, verify_exhausted=False, save_checkpoint_at_end=False) |
| self.assertSequenceEqual(outputs, list(range(100)) + list(range(20))) |
| |
| # Restore from checkpoint and produce the rest of the elements from the |
| # iterator. |
| outputs = outputs[:110] |
| t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True) |
| outputs.extend(t) |
| self.assertSequenceEqual( |
| outputs, |
| list(range(100)) + list(range(10)) + list(range(10, 100))) |
| |
| |
| class LegacySnapshotCheckpointTest( |
| checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): |
| |
| def _build_snapshot_dataset(self, |
| num_threads=1, |
| repeat=False, |
| pending_snapshot_expiry_seconds=-1, |
| shard_size_bytes=None): |
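| """Returns a function building a legacy-snapshotted `range(1000)` dataset. |
| |
| If `repeat` is True, the dataset is repeated twice after the snapshot. |
| """ |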
| |
| def ds_fn(): |
| self.snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot") |
| if not os.path.exists(self.snapshot_dir): |
| os.mkdir(self.snapshot_dir) |
| dataset = dataset_ops.Dataset.range(1000) |
| dataset = dataset.apply( |
| snapshot.legacy_snapshot( |
| self.snapshot_dir, |
| num_writer_threads=num_threads, |
| writer_buffer_size=2 * num_threads, |
| num_reader_threads=num_threads, |
| reader_buffer_size=2 * num_threads, |
| pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds, |
| shard_size_bytes=shard_size_bytes)) |
| if repeat: |
| dataset = dataset.repeat(2) |
| return dataset |
| |
| return ds_fn |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(pending_snapshot_expiry_seconds=[None, 1]))) |
| def testSnapshotBeforeEpochEnd(self, pending_snapshot_expiry_seconds): |
| ds_fn = self._build_snapshot_dataset( |
| pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds) |
| outputs = self.gen_outputs(ds_fn, [], 100, verify_exhausted=False) |
| self.assertSequenceEqual(outputs, range(100)) |
| outputs.extend( |
| self.gen_outputs( |
| ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)) |
| self.assertSequenceEqual(outputs, range(1000)) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(pending_snapshot_expiry_seconds=[None, 1]))) |
| def testCheckpointBeforeOneEpochThenRunFewStepsSmallShardMultiThread( |
| self, pending_snapshot_expiry_seconds): |
| ds_fn = self._build_snapshot_dataset( |
| pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds, |
| shard_size_bytes=100) |
| |
| outputs = [] |
| with ops.Graph().as_default() as g: |
| init_op, get_next_op, saver = self._build_graph(ds_fn) |
| with self.session(graph=g) as sess: |
| self._initialize(init_op, sess) |
| start = 0 |
| end = 100 |
| num_iters = end - start |
| for _ in range(num_iters): |
| outputs.append(sess.run(get_next_op)) |
| self._save(sess, saver) |
| start = 100 |
| end = 400 |
| num_iters = end - start |
| for _ in range(num_iters): |
| outputs.append(sess.run(get_next_op)) |
| self.assertSequenceEqual(outputs, range(400)) |
| |
| outputs = outputs[:100] |
| outputs.extend( |
| self.gen_outputs( |
| ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)) |
| self.assertSequenceEqual(outputs, range(1000)) |
| fp_dir_list = os.listdir(self.snapshot_dir) |
| self.assertLen(list(fp_dir_list), 2) |
| for d in fp_dir_list: |
| if not d.endswith("-graph.pbtxt"): |
| fp_dir = os.path.join(self.snapshot_dir, d) |
| run_dir_list = os.listdir(fp_dir) |
| self.assertLen(list(run_dir_list), 2) |
| for e in run_dir_list: |
| if e != "snapshot.metadata": |
| run_dir = os.path.join(fp_dir, e) |
| self.assertLen(list(os.listdir(run_dir)), 258) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(pending_snapshot_expiry_seconds=[None, 1]))) |
| def testCheckpointBeforeOneEpochThenRunFewSteps( |
| self, pending_snapshot_expiry_seconds): |
| ds_fn = self._build_snapshot_dataset( |
| pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds) |
| |
| # Generate 200 entries from iterator but save checkpoint after producing |
| # 100. |
| outputs = self.gen_outputs( |
| ds_fn, [100], 200, verify_exhausted=False, save_checkpoint_at_end=False) |
| self.assertSequenceEqual(outputs, range(200)) |
| |
| outputs = outputs[:100] |
| outputs.extend( |
| self.gen_outputs( |
| ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)) |
| self.assertSequenceEqual(outputs, range(1000)) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(pending_snapshot_expiry_seconds=[None, 1]))) |
| def testCheckpointBeforeOneEpochThenRunFewStepsMultipleThreads( |
| self, pending_snapshot_expiry_seconds): |
| ds_fn = self._build_snapshot_dataset( |
| num_threads=2, |
| pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds) |
| |
| # Generate 200 entries from iterator but save checkpoint after producing |
| # 100. |
| outputs = self.gen_outputs( |
| ds_fn, [100], 200, verify_exhausted=False, save_checkpoint_at_end=False) |
| self.assertSequenceEqual(outputs, range(200)) |
| |
| outputs = outputs[:100] |
| outputs.extend( |
| self.gen_outputs( |
| ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)) |
| self.assertSequenceEqual(outputs, range(1000)) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(pending_snapshot_expiry_seconds=[None, 1]))) |
| def testCheckpointAfterOneEpoch(self, pending_snapshot_expiry_seconds): |
| ds_fn = self._build_snapshot_dataset( |
| repeat=True, |
| pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds) |
| |
| # Generate 1100 entries from iterator and save checkpoint. |
| outputs = self.gen_outputs(ds_fn, [], 1100, verify_exhausted=False) |
| self.assertSequenceEqual(outputs, list(range(1000)) + list(range(100))) |
| |
| # Restore from checkpoint and produce the rest of the elements from the |
| # iterator. |
| t = self.gen_outputs( |
| ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False) |
| outputs.extend(t) |
| self.assertSequenceEqual( |
| outputs, |
| list(range(1000)) + list(range(100)) + list(range(900))) |
| |
| @combinations.generate( |
| combinations.times( |
| test_base.default_test_combinations(), |
| combinations.combine(pending_snapshot_expiry_seconds=[None, 1]))) |
| def testCheckpointAfterOneEpochThenRunFewSteps( |
| self, pending_snapshot_expiry_seconds): |
| ds_fn = self._build_snapshot_dataset( |
| repeat=True, |
| pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds) |
| |
| # Generate 1200 entries from iterator but save checkpoint after producing |
| # 1100. |
| outputs = self.gen_outputs( |
| ds_fn, [1100], |
| 1200, |
| verify_exhausted=False, |
| save_checkpoint_at_end=False) |
| self.assertSequenceEqual( |
| outputs, |
| list(range(1000)) + list(range(100)) + list(range(100))) |
| |
| outputs = outputs[:1100] |
| t = self.gen_outputs( |
| ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False) |
| outputs.extend(t) |
| self.assertSequenceEqual( |
| outputs, (list(range(1000)) + list(range(100)) + list(range(900)))) |
| |
| |
| if __name__ == "__main__": |
| test.main() |