# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution Strategy-related dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
class _AutoShardDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that shards the `Dataset` automatically.

  This dataset takes in an existing dataset and tries to automatically figure
  out how to shard the dataset in a multi-worker scenario. Currently, it uses
  Grappler to walk up the dataset graph until it finds a reader dataset (e.g.
  CSVDataset, TFRecordDataset), then inserts a ShardDataset op before that node
  so that each worker only sees some files.

  Args:
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.

  Raises:
    NotFoundError: If we cannot find a suitable reader dataset to begin
      automatically sharding the dataset.
  """
def __init__(self, input_dataset, num_workers, index):
self._input_dataset = input_dataset
self._element_spec = input_dataset.element_spec
variant_tensor = ged_ops.auto_shard_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
index=index,
**self._flat_structure)
super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
@property
def element_spec(self):
return self._element_spec
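

# A hedged usage sketch for `_AutoShardDataset` (the file names and worker
# count below are assumptions for illustration; this class is private and is
# normally reached through tf.distribute rather than constructed directly):
#
#   dataset = tf.data.TFRecordDataset(["train-0.tfrecord", "train-1.tfrecord"])
#   # The worker with index 1 out of 2 only sees the files assigned to it.
#   sharded = _AutoShardDataset(dataset, num_workers=2, index=1)
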
def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=invalid-name
return dataset_ops.DatasetV1Adapter(
_AutoShardDataset(input_dataset, num_workers, index))
class _RebatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that divides the batch size by `num_replicas`.

  For each batch in the input dataset, the resulting dataset will produce
  `num_replicas` minibatches whose sizes add up to the original batch size.
  """
def __init__(self, input_dataset, num_replicas, use_fallback=True):
self._input_dataset = input_dataset
def recalculate_output_shapes(output_shapes):
"""Recalculates the output_shapes after dividing it by num_replicas."""
if len(output_shapes) < 1:
raise ValueError(
"Input shape should have at least one dimension. "
"Perhaps your input dataset is not batched?")
output_dims = [d.value for d in output_shapes.dims]
if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
output_dims[0] = output_dims[0] // num_replicas
else:
        # Set the batch dimension to unknown. If the global batch size is not
        # evenly divisible by num_replicas, the minibatches may have different
        # sizes.
output_dims[0] = None
return tensor_shape.TensorShape(output_dims)
input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
self._element_spec = structure.convert_legacy_structure(
input_types, output_shapes, input_classes)
variant_tensor = ged_ops.rebatch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_replicas=num_replicas,
**self._flat_structure)
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def element_spec(self):
return self._element_spec
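

# A hedged sketch of the rebatching arithmetic (the batch size and replica
# count below are assumptions for illustration; `_RebatchDataset` is private):
# with a global batch size of 64 and `num_replicas=4`, each input batch yields
# four minibatches of 16, so the static leading dimension becomes 64 // 4. If
# the batch size were not evenly divisible by `num_replicas`, the leading
# dimension would instead be reported as unknown (None).
#
#   batched = dataset_ops.Dataset.range(1024).batch(64, drop_remainder=True)
#   rebatched = _RebatchDataset(batched, num_replicas=4)
#   # rebatched.element_spec reports shape (16,) for each minibatch.
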
class _RemoteDataset(dataset_ops.DatasetSource):
"""Creates a dataset on a given `device` given a graph def."""
def __init__(self, graph_def, device, element_spec):
self._elem_spec = element_spec
with ops.device(device):
variant_tensor = ged_ops.dataset_from_graph(graph_def)
super(_RemoteDataset, self).__init__(variant_tensor)
@property
def element_spec(self):
return self._elem_spec
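

# Internal usage sketch: `replicate` below serializes an existing dataset to a
# GraphDef and rebuilds it on each target device via `_RemoteDataset`. The
# device string here is an assumption for illustration only:
#
#   graph_def = dataset._as_serialized_graph(  # pylint: disable=protected-access
#       allow_stateful=True)
#   remote = _RemoteDataset(graph_def, "/job:worker/task:1/device:CPU:0",
#                           dataset.element_spec)
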
def replicate(dataset, devices):
"""A transformation that replicates `dataset` onto a list of devices.

  Args:
    dataset: A `tf.data.Dataset` object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device.
  """
if not isinstance(dataset, dataset_ops.DatasetV2):
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
# pylint: disable=protected-access
dataset_device = dataset._variant_tensor.device
datasets = {}
if len(devices) == 1 and devices[0] == dataset_device:
datasets[devices[0]] = dataset
return datasets
with ops.colocate_with(dataset._variant_tensor):
dataset = dataset._apply_options()
allow_stateful = dataset.options().experimental_allow_stateful
graph_def = dataset._as_serialized_graph(allow_stateful=allow_stateful)
for device in devices:
ds = _RemoteDataset(graph_def, device, dataset.element_spec)
datasets[device] = ds
return datasets
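

# A hedged end-to-end sketch of `replicate` (device names are assumptions for
# illustration; real names depend on the cluster configuration). Note that
# `replicate` copies the full dataset onto every device; it does not shard it:
#
#   dataset = dataset_ops.Dataset.range(100).batch(10)
#   per_device = replicate(dataset, ["/job:worker/task:0/device:CPU:0",
#                                    "/job:worker/task:1/device:CPU:0"])
#   # `per_device["/job:worker/task:1/device:CPU:0"]` iterates the same
#   # elements as `dataset`, but its ops live on that worker's device.
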
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__