# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `SnapshotDataset` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import shutil
import time
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import snapshot
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.kernel_tests import tf_record_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
def is_graphdef_file(filename):
return filename.endswith("-graph.pbtxt")
def is_temp_file(filename):
return "-tmp-" in filename
def listdir_and_filter(dirname, filter_fn):
return [path for path in sorted(os.listdir(dirname)) if filter_fn(path)]
class SnapshotTest(tf_record_test_base.TFRecordTestBase,
parameterized.TestCase):
def setUp(self):
super(SnapshotTest, self).setUp()
tmpdir = self.get_temp_dir()
tmpdir = os.path.join(tmpdir, "snapshot")
os.mkdir(tmpdir)
self._snapshot_dir = tmpdir
def tearDown(self):
super(SnapshotTest, self).tearDown()
shutil.rmtree(self._snapshot_dir)
def createTFRecords(self, num_files=10, num_records=100):
self._num_files = num_files
self._num_records = num_records
self._filenames = self._createFiles()
def removeTFRecords(self):
for filename in self._filenames:
os.remove(filename)
self._filenames = []
self._num_files = None
self._num_records = None
def assertDatasetProducesSet(self, dataset, expected):
actual = []
next_fn = self.getNext(dataset)
for _ in range(len(expected)):
elem = self.evaluate(next_fn())
actual.append(elem)
self.assertCountEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_fn())
def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
num_runs_per_fingerprint,
num_snapshot_shards_per_run):
# Ignore the graphdef pbtxts we write for debugging purposes and temporary
# files that are an artifact of how TF writes files.
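    # The on-disk layout that the assertions below walk through:
    #   <directory>/<fingerprint>/snapshot.metadata
    #   <directory>/<fingerprint>/<run id>/00000000.shard, 00000001.shard, ...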
dirlist = listdir_and_filter(
directory,
lambda p: not (is_graphdef_file(p) or is_temp_file(p)))
self.assertLen(dirlist, num_fingerprints)
for i in range(num_fingerprints):
fingerprint_dir = os.path.join(directory, dirlist[i])
fingerprint_dir_list = listdir_and_filter(
fingerprint_dir, lambda p: not is_temp_file(p))
self.assertLen(fingerprint_dir_list, num_runs_per_fingerprint + 1)
self.assertEqual(fingerprint_dir_list[num_runs_per_fingerprint],
"snapshot.metadata")
for j in range(num_runs_per_fingerprint):
run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
run_dirlist = sorted(os.listdir(run_dir))
self.assertLen(run_dirlist, num_snapshot_shards_per_run)
file_counter = 0
for filename in run_dirlist:
self.assertEqual(filename, "%08d.shard" % file_counter)
file_counter += 1
@combinations.generate(test_base.default_test_combinations())
def testCreateSnapshotDataset(self):
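    # Constructing the snapshot dataset should succeed without iterating it.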
dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
dataset.snapshot(path=self._snapshot_dir)
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetDefault(self):
self.createTFRecords()
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 100)
]
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
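    # Delete the source files; the second read must be served entirely from
    # the snapshot.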
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetAutoWriteSnappyRead(self):
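    # Write with compression="AUTO" and read back with an explicit "SNAPPY"
    # reader; as the test name suggests, AUTO is expected to resolve to
    # Snappy here.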
self.createTFRecords()
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 100)
]
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.snapshot(path=self._snapshot_dir, compression="AUTO")
self.assertDatasetProduces(dataset, expected)
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.snapshot(
path=self._snapshot_dir,
compression="SNAPPY"
)
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetCustomShardFn(self):
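    # A constant shard_func routes every element to shard 0, so each run
    # directory should contain exactly one shard file.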
self.createTFRecords()
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 100)
]
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.snapshot(
path=self._snapshot_dir,
shard_func=lambda _: np.int64(0)
)
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=1)
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.snapshot(
path=self._snapshot_dir,
shard_func=lambda _: 0
)
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetCustomReaderFn(self):
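    # reader_func receives a dataset whose elements are per-shard datasets;
    # the lambda below flattens them by interleaving four shards in parallel.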
self.createTFRecords()
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 100)
]
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.snapshot(
path=self._snapshot_dir,
reader_func=(
lambda ds: ds.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=4,
num_parallel_calls=4)))
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.snapshot(
self._snapshot_dir,
reader_func=(
lambda ds: ds.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=4,
num_parallel_calls=4)))
self.assertDatasetProducesSet(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testSnapshotDatasetInvalidShardFn(self):
dataset = dataset_ops.Dataset.range(1000)
with self.assertRaises(TypeError):
dataset = dataset.snapshot(
path=self._snapshot_dir,
shard_func=lambda _: "invalid_fn")
next_fn = self.getNext(dataset)
self.evaluate(next_fn())
@combinations.generate(test_base.default_test_combinations())
def testSnapshotDatasetInvalidReaderFn(self):
dataset = dataset_ops.Dataset.range(1000)
with self.assertRaises(TypeError):
dataset = dataset.snapshot(
path=self._snapshot_dir,
reader_func=lambda x: x + 1
)
next_fn = self.getNext(dataset)
self.evaluate(next_fn())
@combinations.generate(test_base.default_test_combinations())
def testRoundtripEmptySnapshot(self):
dataset = dataset_ops.Dataset.range(0)
dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, [])
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=0)
dataset2 = dataset_ops.Dataset.range(0)
    dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, [])
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSimple(self):
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetMultipleFingerprints(self):
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset1, list(range(1000)))
dataset2 = dataset_ops.Dataset.range(2000)
dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, list(range(2000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=2,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset1, list(range(1000)))
dataset2 = dataset_ops.Dataset.range(1000)
dataset2 = dataset2.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset2, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self):
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.snapshot(path=self._snapshot_dir)
next1 = self.getNext(dataset1)
for i in range(500):
self.assertEqual(i, self.evaluate(next1()))
dataset2 = dataset_ops.Dataset.range(1000)
dataset2 = dataset2.snapshot(path=self._snapshot_dir)
next2 = self.getNext(dataset2)
for i in range(500):
self.assertEqual(i, self.evaluate(next2()))
for i in range(500, 1000):
self.assertEqual(i, self.evaluate(next1()))
self.assertEqual(i, self.evaluate(next2()))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=2,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotCustomShardFunction(self):
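    # enumerate() supplies an element index, and sharding on index parity
    # should produce exactly two shard files per run.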
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.enumerate()
dataset = dataset.snapshot(
path=self._snapshot_dir,
shard_func=lambda i, _: i % 2
)
dataset = dataset.map(lambda _, elem: elem)
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=2)
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetWithTuples(self):
dataset1 = dataset_ops.Dataset.range(0, 1000)
dataset2 = dataset_ops.Dataset.range(1000, 2000)
dataset3 = dataset_ops.Dataset.range(2000, 3000)
dataset4 = dataset_ops.Dataset.range(3000, 4000)
dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4))
dataset = dataset.snapshot(path=self._snapshot_dir)
expected = list(
zip(
range(0, 1000), range(1000, 2000), range(2000, 3000),
range(3000, 4000)))
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotShuffleSameFingerprint(self):
def make_dataset():
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.shuffle(1000)
dataset = dataset.snapshot(path=self._snapshot_dir)
return dataset
dataset1 = make_dataset()
self.assertDatasetProducesSet(dataset1, list(range(1000)))
dataset2 = make_dataset()
self.assertDatasetProducesSet(dataset2, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testReadUsingFlatMap(self):
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProduces(dataset, list(range(1000)))
flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x)
self.assertDatasetProduces(flat_map, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testReadOptimizableUsingFlatMap(self):
dataset = dataset_ops.Dataset.range(1000)
# Will be optimized into ShuffleAndRepeat.
dataset = dataset.shuffle(10)
dataset = dataset.repeat(2)
dataset = dataset.snapshot(path=self._snapshot_dir)
self.assertDatasetProducesSet(dataset, 2 * list(range(1000)))
flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x)
self.assertDatasetProducesSet(flat_map, 2 * list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testRepeatAndPrefetch(self):
"""This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
dataset = dataset.snapshot(path=self._snapshot_dir)
dataset = dataset.shuffle(buffer_size=16)
dataset = dataset.batch(16)
dataset = dataset.repeat()
dataset = dataset.prefetch(1)
next_element = self.getNext(dataset)
for _ in range(30):
self.evaluate(next_element())
class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase,
parameterized.TestCase):
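  # Unlike the `dataset.snapshot(...)` method exercised above, the legacy API
  # is applied as a transformation:
  # `dataset.apply(snapshot.legacy_snapshot(path, ...))`.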
def setUp(self):
super(LegacySnapshotTest, self).setUp()
self.removeTFRecords()
tmpdir = self.get_temp_dir()
tmpdir = os.path.join(tmpdir, "snapshot")
os.mkdir(tmpdir)
self.snapshot_dir = tmpdir
def tearDown(self):
super(LegacySnapshotTest, self).tearDown()
shutil.rmtree(self.snapshot_dir)
def removeTFRecords(self):
for filename in self._filenames:
os.remove(filename)
self._filenames = []
def setUpTFRecord(self, num_files=10, num_records=10):
self._num_files = num_files
self._num_records = num_records
self._filenames = self._createFiles()
def makeSnapshotDirectory(self):
return self.snapshot_dir
def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
num_runs_per_fp, num_snapshot_files):
# Ignore the graphdef pbtxts we write for debugging purposes and temporary
# files that are an artifact of how TF writes files.
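    # The on-disk layout for legacy snapshots, as asserted below:
    #   <directory>/<fingerprint>/snapshot.metadata
    #   <directory>/<fingerprint>/<run id>/00000000.snapshot, ...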
dirlist = listdir_and_filter(
directory,
lambda p: not (is_graphdef_file(p) or is_temp_file(p)))
self.assertLen(dirlist, num_fingerprints)
for i in range(num_fingerprints):
fingerprint_dir = os.path.join(directory, dirlist[i])
fingerprint_dir_list = listdir_and_filter(
fingerprint_dir, lambda p: not is_temp_file(p))
self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1)
self.assertEqual(fingerprint_dir_list[num_runs_per_fp],
"snapshot.metadata")
for j in range(num_runs_per_fp):
run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
run_dirlist = sorted(os.listdir(run_dir))
self.assertLen(run_dirlist, num_snapshot_files)
file_counter = 0
for filename in run_dirlist:
self.assertEqual(filename, "%08d.snapshot" % file_counter)
file_counter += 1
@combinations.generate(test_base.default_test_combinations())
def testWriteDifferentPipelinesInOneDirectory(self):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
self.assertDatasetProduces(dataset, list(range(1000)))
dataset = dataset_ops.Dataset.range(1001)
dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
self.assertDatasetProduces(dataset, list(range(1001)))
self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotMultipleSimultaneous(self):
tmpdir = self.snapshot_dir
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
next1 = self.getNext(dataset1)
dataset2 = dataset_ops.Dataset.range(1000)
dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
next2 = self.getNext(dataset2)
for i in range(0, 1000):
self.assertEqual(i, self.evaluate(next1()))
self.assertEqual(i, self.evaluate(next2()))
    # We check that only one copy of the metadata has been written; the
    # pipeline that lost the race falls back to passthrough mode.
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations())
def testGetNextCreatesDir(self):
tmpdir = self.snapshot_dir
# We create two iterators but call getNext on only one.
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
next1 = self.getNext(dataset1)
dataset2 = dataset_ops.Dataset.range(1001)
dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
_ = self.getNext(dataset2)
for _ in range(1000):
self.evaluate(next1())
# We check that only one directory is created.
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testWriteSnapshotSimpleSuccessful(self, compression):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, compression=compression))
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testWriteSnapshotRepeatAfterwards(self, compression):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, compression=compression))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testWriteSnapshotMixTypes(self, compression):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
def map_fn(x):
return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x)
dataset = dataset.map(map_fn)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, compression=compression))
dataset = dataset.repeat(10)
expected = []
for i in range(10):
expected.append((i, str(i), str(2 * i), 2 * i))
self.assertDatasetProduces(dataset, expected * 10)
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations())
def testSpecifySnapshotNameWriteAndRead(self):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, snapshot_name="my_custom_snapshot"))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
self.assertTrue(
os.path.exists(os.path.join(tmpdir, "custom-my_custom_snapshot")))
self.assertTrue(
os.path.exists(
os.path.join(tmpdir, "custom-my_custom_snapshot", "custom")))
@combinations.generate(test_base.default_test_combinations())
def testForcePassthroughMode(self):
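    # In "passthrough" mode the data flows through without being written, so
    # the snapshot directory should remain empty.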
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, mode="passthrough"))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0)
@combinations.generate(test_base.default_test_combinations())
def testForceWriteMode(self):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="write"))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
# We will end up writing 10 different runs.
self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1)
@combinations.generate(test_base.default_test_combinations())
def testForceReadMode(self):
tmpdir = self.snapshot_dir
# We write a copy of the snapshot first.
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.legacy_snapshot(
tmpdir, mode="write", snapshot_name="my_custom_snapshot"))
self.assertDatasetProduces(dataset, list(range(10)))
# We move the run to a new name.
shutil.move(
os.path.join(tmpdir, "custom-my_custom_snapshot"),
os.path.join(tmpdir, "custom-my_custom_snapshot_2"))
    # Even though snapshot.metadata points to the old run, which no longer
    # exists after the move, we force the dataset to read from the run we
    # specify.
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.legacy_snapshot(
tmpdir, mode="read", snapshot_name="my_custom_snapshot_2"))
self.assertDatasetProduces(dataset, list(range(10)))
# We should still have one snapshot and one run.
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations())
def testForceReadNonexistentSnapshot(self):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
with self.assertRaises(errors.NotFoundError):
dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="read"))
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testForceReadNonexistentNamedSnapshot(self):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
with self.assertRaises(errors.NotFoundError):
dataset = dataset.apply(
snapshot.legacy_snapshot(
tmpdir, mode="read", snapshot_name="my_nonexistent_snapshot"))
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testReadSnapshotBackAfterWrite(self, compression):
self.setUpTFRecord()
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 10)
]
tmpdir = self.snapshot_dir
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, compression=compression))
self.assertDatasetProduces(dataset, expected)
    # Remove the original files and try to read the data back only from the
    # snapshot.
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(
snapshot.legacy_snapshot(tmpdir, compression=compression))
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testReadShuffledSnapshotAfterWrite(self):
self.setUpTFRecord(num_files=10, num_records=50)
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 50)
]
tmpdir = self.snapshot_dir
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100))
self.assertDatasetProduces(dataset, expected)
    # Remove the original files and try to read the data back only from the
    # snapshot.
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(
snapshot.legacy_snapshot(
tmpdir, shard_size_bytes=100, shuffle_on_read=True))
next2 = self.getNext(dataset2)
res1 = self.evaluate(next2())
res2 = self.evaluate(next2())
res3 = self.evaluate(next2())
res4 = self.evaluate(next2())
res5 = self.evaluate(next2())
    # Make sure that we don't read the data back in the same order.
self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5])
    # Make sure all the elements are still there.
dataset3 = core_readers._TFRecordDataset(filenames)
dataset3 = dataset3.apply(
snapshot.legacy_snapshot(
tmpdir, shard_size_bytes=100, shuffle_on_read=True))
self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)
@combinations.generate(test_base.default_test_combinations())
def testReadShuffledSnapshotWithSeedAfterWrite(self):
self.setUpTFRecord(num_files=10, num_records=50)
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 50)
]
tmpdir = self.snapshot_dir
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(
snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10))
self.assertDatasetProduces(dataset, expected)
    # Remove the original files and try to read the data back only from the
    # snapshot.
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(
snapshot.legacy_snapshot(
tmpdir,
shard_size_bytes=10,
shuffle_on_read=True,
shuffle_seed=123456))
next2 = self.getNext(dataset2)
dataset3 = core_readers._TFRecordDataset(filenames)
dataset3 = dataset3.apply(
snapshot.legacy_snapshot(
tmpdir,
shard_size_bytes=10,
shuffle_on_read=True,
shuffle_seed=123456))
next3 = self.getNext(dataset3)
    # Make sure that the items are read back in the same order for both
    # datasets.
for _ in range(500):
res2 = self.evaluate(next2())
res3 = self.evaluate(next3())
self.assertEqual(res2, res3)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testReadSnapshotParallelAfterWrite(self, compression):
self.setUpTFRecord(5, 500)
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 5)
for r in range(0, 500)
]
tmpdir = self.snapshot_dir
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(
snapshot.legacy_snapshot(
tmpdir,
shard_size_bytes=1024 * 1024,
num_reader_threads=2,
reader_buffer_size=10,
compression=compression))
self.assertDatasetProduces(dataset, expected, assert_items_equal=True)
    # Remove the original files and try to read the data back only from the
    # snapshot.
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(
snapshot.legacy_snapshot(
tmpdir,
shard_size_bytes=1024 * 1024,
num_reader_threads=2,
reader_buffer_size=10,
compression=compression))
self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
# Not testing Snappy here because Snappy reads currently require a lot of
# memory.
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.times(
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
]),
combinations.combine(threads=2, size=[1, 2]) +
combinations.combine(threads=8, size=[1, 4, 8]))))
def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
size):
self.setUpTFRecord()
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 10)
]
tmpdir = self.snapshot_dir
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(
snapshot.legacy_snapshot(
tmpdir,
compression=compression,
num_writer_threads=threads,
writer_buffer_size=size))
self.assertDatasetProduces(dataset, expected)
    # Remove the original files and try to read the data back only from the
    # snapshot.
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(
snapshot.legacy_snapshot(tmpdir, compression=compression))
self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
@combinations.generate(test_base.default_test_combinations())
def testSameFingerprintWithDifferentInitializationOrder(self):
tmpdir = self.snapshot_dir
dataset1 = dataset_ops.Dataset.range(0, 100)
dataset2 = dataset_ops.Dataset.range(100, 200)
dataset3 = dataset_ops.Dataset.range(200, 300)
dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
self.assertDatasetProduces(dataset, list(range(300)))
dataset4 = dataset_ops.Dataset.range(200, 300)
dataset5 = dataset_ops.Dataset.range(100, 200)
dataset6 = dataset_ops.Dataset.range(0, 100)
dataset = dataset6.concatenate(dataset5).concatenate(dataset4)
dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
self.assertDatasetProduces(dataset, list(range(300)))
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations())
def testExpiredSnapshotRewrite(self):
tmpdir = self.snapshot_dir
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.apply(
snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
next1 = self.getNext(dataset1)
    # Don't finish reading dataset1, so it is never finalized.
for _ in range(500):
self.evaluate(next1())
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
time.sleep(2)
    # We create dataset2 only after running through dataset1 because, in eager
    # mode, the snapshot state is determined immediately upon dataset creation,
    # and we want the snapshot state for dataset2 to be determined only after
    # the first snapshot has expired.
dataset2 = dataset_ops.Dataset.range(1000)
dataset2 = dataset2.apply(
snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
next2 = self.getNext(dataset2)
for _ in range(500):
self.evaluate(next2())
self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1)
@combinations.generate(test_base.default_test_combinations())
def testSnapshotArgsCreateNewSnapshot(self):
tmpdir = self.snapshot_dir
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.apply(
snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10000))
next1 = self.getNext(dataset1)
for _ in range(1000):
self.evaluate(next1())
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
    # Create a second snapshot with a different shard_size_bytes.
dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(
snapshot.legacy_snapshot(tmpdir, shard_size_bytes=20000))
next2 = self.getNext(dataset2)
for _ in range(1000):
self.evaluate(next2())
self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testSpecifyShardSize(self, compression):
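    # Each element is a 1024x1024 float32 tensor (~4 MiB uncompressed), so
    # with a 10 MiB shard size only the uncompressed snapshot should spill
    # into multiple shard files; the constant-valued data compresses well
    # enough to fit in one.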
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
dataset = dataset.repeat(10)
dataset = dataset.apply(
snapshot.legacy_snapshot(
tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression))
next_fn = self.getNext(dataset)
for _ in range(10):
self.evaluate(next_fn())
num_files = 1
if compression == snapshot.COMPRESSION_NONE:
num_files = 3
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files)
@combinations.generate(test_base.default_test_combinations())
def testAdditionalOperationsAfterReadBack(self):
self.setUpTFRecord()
filenames = self._filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 10)
]
tmpdir = self.snapshot_dir
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
self.assertDatasetProduces(dataset, expected)
    # Remove the original files and try to read the data back only from the
    # snapshot.
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
self.assertDatasetProduces(dataset2, expected)
expected_after = [
b"cord %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 10)
]
dataset3 = core_readers._TFRecordDataset(filenames)
dataset3 = dataset3.apply(snapshot.legacy_snapshot(tmpdir))
dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000))
self.assertDatasetProduces(dataset3, expected_after)
class SnapshotCheckpointTest(checkpoint_test_base.CheckpointTestBase,
parameterized.TestCase):
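  # These tests save an iterator checkpoint partway through a snapshot run and
  # verify that restoring the checkpoint resumes element production where the
  # iterator left off.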
def _build_snapshot_dataset(self, repeat=False):
def ds_fn():
self._snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
if not os.path.exists(self._snapshot_dir):
os.mkdir(self._snapshot_dir)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.snapshot(path=self._snapshot_dir)
if repeat:
dataset = dataset.repeat(2)
return dataset
return ds_fn
@combinations.generate(test_base.default_test_combinations())
def testCheckpointBeforeEpochEndNoRepeat(self):
ds_fn = self._build_snapshot_dataset(repeat=False)
outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
self.assertSequenceEqual(outputs, range(50))
outputs.extend(
self.gen_outputs(ds_fn, [], 50, ckpt_saved=True, verify_exhausted=True))
self.assertSequenceEqual(outputs, range(100))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointBeforeOneEpochWithReading(self):
ds_fn = self._build_snapshot_dataset(repeat=True)
    # Generate 50 entries from the iterator and save a checkpoint.
outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(50)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
t = self.gen_outputs(ds_fn, [], 150, ckpt_saved=True, verify_exhausted=True)
outputs.extend(t)
self.assertSequenceEqual(
outputs,
list(range(50)) + list(range(50, 100)) + list(range(100)))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointBeforeOneEpochThenRunAFewSteps(self):
ds_fn = self._build_snapshot_dataset(repeat=False)
outputs = self.gen_outputs(
ds_fn, [10], 20, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, range(20))
outputs = outputs[:10]
outputs.extend(
self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True))
self.assertSequenceEqual(outputs, range(100))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointAfterOneEpoch(self):
ds_fn = self._build_snapshot_dataset(repeat=True)
    # Generate 110 entries from the iterator and save a checkpoint.
outputs = self.gen_outputs(ds_fn, [], 110, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(100)) + list(range(10)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
outputs.extend(t)
self.assertSequenceEqual(
outputs,
list(range(100)) + list(range(10)) + list(range(10, 100)))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointAfterOneEpochRunFewSteps(self):
ds_fn = self._build_snapshot_dataset(repeat=True)
    # Generate 120 entries from the iterator and save a checkpoint at 110.
outputs = self.gen_outputs(
ds_fn, [110], 120, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, list(range(100)) + list(range(20)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
outputs = outputs[:110]
t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
outputs.extend(t)
self.assertSequenceEqual(
outputs,
list(range(100)) + list(range(10)) + list(range(10, 100)))
class LegacySnapshotCheckpointTest(
checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
def _build_snapshot_dataset(self,
num_threads=1,
repeat=False,
pending_snapshot_expiry_seconds=-1,
shard_size_bytes=None):
def ds_fn():
self.snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
if not os.path.exists(self.snapshot_dir):
os.mkdir(self.snapshot_dir)
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.apply(
snapshot.legacy_snapshot(
self.snapshot_dir,
num_writer_threads=num_threads,
writer_buffer_size=2 * num_threads,
num_reader_threads=num_threads,
reader_buffer_size=2 * num_threads,
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
shard_size_bytes=shard_size_bytes))
if repeat:
dataset = dataset.repeat(2)
return dataset
return ds_fn
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
def testSnapshotBeforeEpochEnd(self, pending_snapshot_expiry_seconds):
ds_fn = self._build_snapshot_dataset(
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds)
outputs = self.gen_outputs(ds_fn, [], 100, verify_exhausted=False)
self.assertSequenceEqual(outputs, range(100))
outputs.extend(
self.gen_outputs(
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
self.assertSequenceEqual(outputs, range(1000))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
def testCheckpointBeforeOneEpochThenRunFewStepsSmallShardMultiThread(
self, pending_snapshot_expiry_seconds):
ds_fn = self._build_snapshot_dataset(
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
shard_size_bytes=100)
outputs = []
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(ds_fn)
with self.session(graph=g) as sess:
self._initialize(init_op, sess)
start = 0
end = 100
num_iters = end - start
for _ in range(num_iters):
outputs.append(sess.run(get_next_op))
self._save(sess, saver)
start = 100
end = 400
num_iters = end - start
for _ in range(num_iters):
outputs.append(sess.run(get_next_op))
self.assertSequenceEqual(outputs, range(400))
outputs = outputs[:100]
outputs.extend(
self.gen_outputs(
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
self.assertSequenceEqual(outputs, range(1000))
fp_dir_list = os.listdir(self.snapshot_dir)
self.assertLen(list(fp_dir_list), 2)
for d in fp_dir_list:
if not d.endswith("-graph.pbtxt"):
fp_dir = os.path.join(self.snapshot_dir, d)
run_dir_list = os.listdir(fp_dir)
self.assertLen(list(run_dir_list), 2)
for e in run_dir_list:
if e != "snapshot.metadata":
run_dir = os.path.join(fp_dir, e)
self.assertLen(list(os.listdir(run_dir)), 258)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
def testCheckpointBeforeOneEpochThenRunFewSteps(
self, pending_snapshot_expiry_seconds):
ds_fn = self._build_snapshot_dataset(
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds)
    # Generate 200 entries from the iterator but save a checkpoint after
    # producing 100.
outputs = self.gen_outputs(
ds_fn, [100], 200, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, range(200))
outputs = outputs[:100]
outputs.extend(
self.gen_outputs(
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
self.assertSequenceEqual(outputs, range(1000))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
def testCheckpointBeforeOneEpochThenRunFewStepsMultipleThreads(
self, pending_snapshot_expiry_seconds):
ds_fn = self._build_snapshot_dataset(
num_threads=2,
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds)
    # Generate 200 entries from the iterator but save a checkpoint after
    # producing 100.
outputs = self.gen_outputs(
ds_fn, [100], 200, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, range(200))
outputs = outputs[:100]
outputs.extend(
self.gen_outputs(
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
self.assertSequenceEqual(outputs, range(1000))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
def testCheckpointAfterOneEpoch(self, pending_snapshot_expiry_seconds):
ds_fn = self._build_snapshot_dataset(
repeat=True,
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds)
    # Generate 1100 entries from the iterator and save a checkpoint.
outputs = self.gen_outputs(ds_fn, [], 1100, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(1000)) + list(range(100)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
t = self.gen_outputs(
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)
outputs.extend(t)
self.assertSequenceEqual(
outputs,
list(range(1000)) + list(range(100)) + list(range(900)))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
def testCheckpointAfterOneEpochThenRunFewSteps(
self, pending_snapshot_expiry_seconds):
ds_fn = self._build_snapshot_dataset(
repeat=True,
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds)
    # Generate 1200 entries from the iterator but save a checkpoint after
    # producing 1100.
outputs = self.gen_outputs(
ds_fn, [1100],
1200,
verify_exhausted=False,
save_checkpoint_at_end=False)
self.assertSequenceEqual(
outputs,
list(range(1000)) + list(range(100)) + list(range(100)))
outputs = outputs[:1100]
t = self.gen_outputs(
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)
outputs.extend(t)
self.assertSequenceEqual(
outputs, (list(range(1000)) + list(range(100)) + list(range(900))))
if __name__ == "__main__":
test.main()