# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset snapshot and related functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

COMPRESSION_GZIP = "GZIP"
COMPRESSION_NONE = None


class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A Dataset that captures a snapshot or reads from a snapshot."""
def __init__(self,
input_dataset,
path,
compression=None,
reader_path_prefix=None,
writer_path_prefix=None,
shard_size_bytes=None,
pending_snapshot_expiry_seconds=None,
num_reader_threads=None,
reader_buffer_size=None,
num_writer_threads=None,
writer_buffer_size=None):
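    # Optional arguments left as `None` are forwarded to the kernel as
    # sentinel values ("" for string attributes, -1 for numeric attributes),
    # which the underlying op resolves to its documented defaults.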
self._compression = compression if compression is not None else ""
self._reader_path_prefix = (
reader_path_prefix if reader_path_prefix is not None else "")
self._writer_path_prefix = (
writer_path_prefix if writer_path_prefix is not None else "")
self._shard_size_bytes = (
shard_size_bytes if shard_size_bytes is not None else -1)
self._pending_snapshot_expiry_seconds = (
pending_snapshot_expiry_seconds
if pending_snapshot_expiry_seconds is not None else -1)
self._num_reader_threads = (
num_reader_threads if num_reader_threads is not None else -1)
self._reader_buffer_size = (
reader_buffer_size if reader_buffer_size is not None else -1)
self._num_writer_threads = (
num_writer_threads if num_writer_threads is not None else -1)
self._writer_buffer_size = (
writer_buffer_size if writer_buffer_size is not None else -1)
self._input_dataset = input_dataset
self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
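    # Construct the underlying snapshot op. `**self._flat_structure` forwards
    # the `output_types` and `output_shapes` attrs shared by all dataset ops.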
variant_tensor = ged_ops.snapshot_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
path=self._path,
compression=self._compression,
reader_path_prefix=self._reader_path_prefix,
writer_path_prefix=self._writer_path_prefix,
shard_size_bytes=self._shard_size_bytes,
pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
num_reader_threads=self._num_reader_threads,
reader_buffer_size=self._reader_buffer_size,
num_writer_threads=self._num_writer_threads,
writer_buffer_size=self._writer_buffer_size,
**self._flat_structure)
    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)


def snapshot(path,
compression=None,
reader_path_prefix=None,
writer_path_prefix=None,
shard_size_bytes=None,
pending_snapshot_expiry_seconds=None,
num_reader_threads=None,
reader_buffer_size=None,
num_writer_threads=None,
writer_buffer_size=None):
"""Writes to/reads from a snapshot of a dataset.
This function attempts to determine whether a valid snapshot exists at the
`path`, and reads from the snapshot if so. If not, it will run the
preprocessing pipeline as usual, and write out a snapshot of the data
processed for future use.
Args:
path: A directory where we want to save our snapshots and/or read from a
previously saved snapshot.
compression: The type of compression to apply to the Dataset. Currently
supports "GZIP" or None. Defaults to None (no compression).
reader_path_prefix: A prefix to add to the path when reading from snapshots.
Defaults to None.
writer_path_prefix: A prefix to add to the path when writing to snapshots.
Defaults to None.
shard_size_bytes: The size of each shard to be written by the snapshot
dataset op. Defaults to 10 GiB.
pending_snapshot_expiry_seconds: How long to wait (in seconds) before
the snapshot op considers a previously unfinished snapshot to be stale.
    num_reader_threads: Number of threads to parallelize reading from the
      snapshot. Especially useful if compression is turned on, since the
      decompression operation tends to be intensive. Defaults to 1. If > 1,
      this might introduce non-determinism, i.e. the order in which elements
      are read from the snapshot may differ from the order in which they were
      written.
    reader_buffer_size: Maximum number of elements we can prefetch while
      reading from the snapshot. Defaults to 1. Increasing this might improve
      performance but will increase memory consumption.
    num_writer_threads: Number of threads to parallelize writing to the
      snapshot. We'll open up `num_writer_threads` files and write to them in
      parallel. Especially useful if compression is turned on, since the
      compression operation tends to be intensive. Defaults to 1. If > 1,
      this might introduce non-determinism, i.e. the order in which elements
      are written may differ from the order in which they are produced by the
      upstream iterator.
    writer_buffer_size: Maximum number of pipeline elements to buffer before
      writing them out using `num_writer_threads`.

Returns:
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
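
  Example usage (a minimal sketch; `/path/to/snapshot_dir` is a placeholder
  directory):

  ```python
  dataset = tf.data.Dataset.range(100)
  dataset = dataset.apply(snapshot("/path/to/snapshot_dir"))
  ```

  The first run writes a snapshot under the given directory; subsequent runs
  that find a valid snapshot there read from it instead of re-running the
  preprocessing pipeline.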
"""
def _apply_fn(dataset):
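    # All tuning knobs are captured from the enclosing `snapshot` call.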
return _SnapshotDataset(dataset, path, compression, reader_path_prefix,
writer_path_prefix, shard_size_bytes,
pending_snapshot_expiry_seconds, num_reader_threads,
reader_buffer_size, num_writer_threads,
                            writer_buffer_size)

return _apply_fn