blob: 0f86e44e28138dd95b5d1696e40ccee282574e83 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the CacheDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
class CacheDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):
  """Checkpoint/restore tests for memory- and file-backed CacheDataset."""

  def setUp(self):
    # Chain to the base classes so framework fixtures (temp dirs, test state)
    # are initialized before each parameterized test body runs.
    super(CacheDatasetSerializationTest, self).setUp()
    self.range_size = 10  # elements per epoch
    self.num_repeats = 3  # epochs produced by .repeat()
    self.num_outputs = self.range_size * self.num_repeats
    self.cache_file_prefix = 'test'

  def make_dataset_fn(self, is_memory):
    """Returns a nullary fn that builds range(N).cache(...).repeat(R).

    Args:
      is_memory: If True, use an in-memory cache (signalled to `cache()` by an
        empty filename); otherwise cache to a file under the test temp dir.

    Returns:
      A zero-argument callable producing the dataset, as required by
      `gen_outputs`.
    """
    if is_memory:
      filename = ''
    else:
      filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)

    def ds_fn():
      return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
          self.num_repeats)

    return ds_fn

  def expected_outputs(self):
    """The full expected output stream: 0..range_size-1, num_repeats times."""
    return list(range(self.range_size)) * self.num_repeats

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointBeforeOneEpoch(self, is_memory):
    """Checkpoint mid-first-epoch, then restore and finish the stream."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 5 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
    """Checkpoint at element 5 but keep running to 8 before restoring."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 8 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, range(8))

    # Restore resumes from element 5, so the 3 extra elements are replayed.
    outputs = outputs[:5]
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointAfterOneEpoch(self, is_memory):
    """Checkpoint after the cache is fully written, then restore."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 15 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 15,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
    """Checkpoint at 15 (cache complete) but keep running to 18."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 18 entries from iterator but save checkpoint after producing 15.
    outputs = self.gen_outputs(
        ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))

    # Restoring from the element-15 checkpoint replays elements 5..9 of the
    # second epoch and then the full third epoch.
    outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 15,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
    """Checkpoint in write mode at 5, but run past a full epoch before it."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 13 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))

    # Since we ran for more than one epoch, the cache was completely written.
    # The ckpt was saved when the iterator was in cache-write mode. Test that
    # the iterator falls back to read mode after restoring if the cache has
    # been completely written.
    outputs = list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 5,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointUnusedWriterIterator(self, is_memory):
    """Checkpoint before get_next is called even once."""
    ds_fn = self.make_dataset_fn(is_memory)

    outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
    self.assertSequenceEqual(outputs, [])

    outputs = self.gen_outputs(
        ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
    """Restore a mid-write checkpoint, immediately re-checkpoint, restore."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint, then produce no elements and checkpoint.
    outputs.extend(
        self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce rest of the elements.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testUnusedCheckpointError(self, is_memory):
    """A fresh iterator over a partially-written file cache must fail."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and save ckpt.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    if is_memory:
      # Memory caches are per-iterator, so a fresh iterator simply rebuilds.
      outputs = self.gen_outputs(
          ds_fn, [], self.num_outputs, verify_exhausted=False)
      self.assertSequenceEqual(outputs, self.expected_outputs())
    else:
      # Since the complete cache has not been written, a new iterator which
      # does not restore the checkpoint will throw an error since there is a
      # partial cache shard.
      with self.assertRaises(errors.AlreadyExistsError):
        outputs = self.gen_outputs(
            ds_fn, [], self.num_outputs, verify_exhausted=False)

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testIgnoreCheckpointIfCacheWritten(self, is_memory):
    """A fresh iterator may reuse a completely-written cache."""
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 15 elements and save ckpt. This will write the complete cache.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Build the iterator again but do not restore from ckpt. Since the cache
    # has already been written we should be able to use it.
    outputs = self.gen_outputs(
        ds_fn, [], self.num_outputs, verify_exhausted=False)
    self.assertSequenceEqual(outputs, self.expected_outputs())
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  test.main()