blob: 34e6c397602bae02d62a3950cace64c154c4a33c [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import six
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
def get_distributed_dataset(dataset,
input_workers,
strategy,
split_batch_by=None,
input_context=None):
"""Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
This is a common function that is used by all strategies to return the right
tf.data.Dataset wrapped instance depending on the `dataset` argument type.
Args:
dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance.
input_workers: an InputWorkers object which specifies devices on which
iterators should be created.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
`num_input_pipelines` in the `InputContext`.
Returns:
A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
"""
if isinstance(dataset, dataset_ops.DatasetV1):
return DistributedDatasetV1(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
input_context=input_context)
else:
return DistributedDataset(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
input_context=input_context)
def get_distributed_datasets_from_function(dataset_fn,
input_workers,
input_contexts,
strategy):
"""Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
This is a common function that is used by all strategies to return the right
tf.data.Dataset wrapped instance depending on if we are in graph or eager
mode.
Args:
dataset_fn: a function that returns a tf.data.DatasetV1 or tf.data.DatasetV2
instance.
input_workers: an InputWorkers object which specifies devices on which
iterators should be created.
input_contexts: A list of `InputContext` instances to be passed to call(s)
to `dataset_fn`. Length and order should match worker order in
`worker_device_pairs`.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
Returns:
A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
"""
if ops.executing_eagerly_outside_functions():
return DistributedDatasetsFromFunction(
dataset_fn,
input_workers,
input_contexts,
strategy)
else:
return DistributedDatasetsFromFunctionV1(
dataset_fn,
input_workers,
input_contexts,
strategy)
class InputWorkers(object):
"""A 1-to-many mapping from input worker devices to compute devices."""
def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
"""Initialize an `InputWorkers` object.
Args:
device_map: A `DeviceMap` with the computation devices fed by the
input workers.
worker_device_pairs: A sequence of pairs:
`(input device, a tuple of compute devices fed by that input device)`.
logical_device: The logical device of `device_map` to feed.
"""
self._device_map = device_map
self._logical_device = logical_device
if worker_device_pairs is None:
worker_device_pairs = ((
device_util.canonicalize("/device:CPU:0"),
device_map.logical_to_actual_devices(logical_device)),)
self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
for _, f in worker_device_pairs)
flattened = tuple(d for l in self._fed_devices for d in l)
assert (flattened ==
device_map.logical_to_actual_devices(logical_device)), (
"flattened: %s logical device %d: %s" %
(flattened, logical_device,
device_map.logical_to_actual_devices(logical_device)))
@property
def device_map(self):
return self._device_map
@property
def logical_device(self):
return self._logical_device
@property
def num_workers(self):
return len(self._input_worker_devices)
@property
def worker_devices(self):
return self._input_worker_devices
def compute_devices_for_worker(self, worker_index):
return self._fed_devices[worker_index]
def __repr__(self):
devices = self.worker_devices
debug_repr = ",\n".join(" %d %s: %s" %
(i, devices[i], self._fed_devices[i])
for i in range(len(devices)))
return "%s:{\n%s\n device_map: %s}" % (
self.__class__.__name__, debug_repr, self._device_map)
def _get_next_as_optional(iterator, strategy, name=None):
"""Returns an empty dataset indicator and the next input from the iterator."""
replicas = []
worker_has_values = []
worker_devices = []
for i, worker in enumerate(iterator._input_workers.worker_devices): # pylint: disable=protected-access
if name is not None:
d = tf_device.DeviceSpec.from_string(worker)
new_name = "%s_%s_%d" % (name, d.job, d.task)
else:
new_name = None
with ops.device(worker):
worker_has_value, next_element = (
iterator._iterators[i].get_next_as_list(new_name)) # pylint: disable=protected-access
# Collective all-reduce requires explict devices for inputs.
with ops.device("/cpu:0"):
# Converting to integers for all-reduce.
worker_has_value = math_ops.cast(worker_has_value, dtypes.int32)
worker_devices.append(worker_has_value.device)
worker_has_values.append(worker_has_value)
# Make `replicas` a flat list of values across all replicas.
replicas.append(next_element)
# Run an all-reduce to see whether any worker has values.
# TODO(b/131423105): we should be able to short-cut the all-reduce in some
# cases.
if getattr(strategy.extended, "_support_per_replica_values", True):
worker_has_values = values.PerReplica(
values.WorkerDeviceMap(
worker_devices,
num_replicas_per_worker=len(
strategy.extended._input_workers._input_worker_devices)), # pylint: disable=protected-access
worker_has_values)
global_has_value = strategy.reduce(
reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
else:
assert len(worker_has_values) == 1
global_has_value = worker_has_values[0]
global_has_value = array_ops.reshape(
math_ops.cast(global_has_value, dtypes.bool), [])
return global_has_value, replicas
class DistributedIterator(object):
"""Common implementation for all input iterators."""
def __init__(self, input_workers, iterators, strategy):
static_shape = True
for iterator in iterators:
if not isinstance(iterator, _SingleWorkerDatasetIterator):
continue
flattened_shapes = nest.flatten(iterator.output_shapes)
for output_shape in flattened_shapes:
if not output_shape.is_fully_defined():
static_shape = False
break
# TODO(b/133073708): we currently need a flag to control the usage because
# there is a performance difference between get_next() and
# get_next_as_optional(). And we only enable get_next_as_optional when the
# output shapes are not static.
#
# TODO(yuefengz): Currently `experimental_enable_get_next_as_optional` is
# always set to False in CollectiveAllReduceStrategy. We want to have a way
# to distinguish multi workers/single worker between graph, so we can enable
# the behavior in single worker case.
#
# TODO(rxsang): We want to always enable the get_next_as_optional behavior
# when user passed input_fn instead of dataset.
if getattr(
strategy.extended, "experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = not static_shape
else:
self._enable_get_next_as_optional = False
assert isinstance(input_workers, InputWorkers)
if not input_workers.worker_devices:
raise ValueError("Should have at least one worker for input iterator.")
self._iterators = iterators
self._input_workers = input_workers
self._strategy = strategy
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except errors.OutOfRangeError:
raise StopIteration
def get_next(self, name=None):
"""Returns the next input from the iterator for all replicas."""
if not self._enable_get_next_as_optional:
replicas = []
for i, worker in enumerate(self._input_workers.worker_devices):
if name is not None:
d = tf_device.DeviceSpec.from_string(worker)
new_name = "%s_%s_%d" % (name, d.job, d.task)
else:
new_name = None
with ops.device(worker):
# Make `replicas` a flat list of values across all replicas.
replicas.extend(
self._iterators[i].get_next_as_list_deprecated(new_name))
return values.regroup(self._input_workers.device_map, replicas)
out_of_range_replicas = []
def out_of_range_fn(worker_index, device):
"""This function will throw an OutOfRange error."""
# As this will be only called when there is no data left, so calling
# get_next() will trigger an OutOfRange error.
data = self._iterators[worker_index].get_next(device)
out_of_range_replicas.append(data)
return data
global_has_value, replicas = _get_next_as_optional(self, self._strategy)
results = []
for i, worker in enumerate(self._input_workers.worker_devices):
with ops.device(worker):
devices = self._input_workers.compute_devices_for_worker(i)
for j, device in enumerate(devices):
with ops.device(device):
# pylint: disable=undefined-loop-variable
# pylint: disable=cell-var-from-loop
# It is fine for the lambda to capture variables from the loop as
# the lambda is executed in the loop as well.
result = control_flow_ops.cond(global_has_value,
lambda: replicas[i][j],
lambda: out_of_range_fn(i, device))
# pylint: enable=cell-var-from-loop
# pylint: enable=undefined-loop-variable
results.append(result)
replicas = results
# Some dimensions in `replicas` will become unknown after we conditionally
# return the real tensors or the dummy tensors. We fix the input shapes by
# using the shapes from `out_of_range_replicas` because it is calling
# get_next() inside.
flattened_replicas = nest.flatten(replicas)
for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
flattened_replicas[i].set_shape(replica_data.get_shape())
replicas = nest.pack_sequence_as(replicas, flattened_replicas)
return values.regroup(self._input_workers.device_map, replicas)
# We need a private initializer method for re-initializing multidevice
# iterators when used with Keras training loops. If we don't reinitialize the
# iterator we run into memory leak issues (b/123315763).
@property
def _initializer(self):
init_ops = []
for it in self._iterators:
init_ops.extend(it.initialize())
return control_flow_ops.group(init_ops)
class DistributedIteratorV1(DistributedIterator):
"""Input Iterator for tf.data.DatasetV1."""
# TODO(anjalisridhar): Move to using `initializer` instead to be consistent
# with tf.data iterator APIs.
def initialize(self):
"""Initialze underlying iterators.
Returns:
A list of any initializer ops that should be run.
"""
return super(DistributedIteratorV1, self)._initializer
@property
def initializer(self):
return self.initialize()
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_classes(self):
return self._iterators[0].output_classes
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_shapes(self):
return self._iterators[0].output_shapes
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@property
def output_types(self):
return self._iterators[0].output_types
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
def get_iterator(self, worker):
for i, w in enumerate(self._input_workers.worker_devices):
if worker == w:
return self._iterators[i]
return None
class _IterableInput(object):
"""Base class for iterable inputs for distribution strategies."""
def __init__(self, input_workers):
assert isinstance(input_workers, InputWorkers)
self._input_workers = input_workers
def __iter__(self):
raise NotImplementedError("must be implemented in descendants")
def _autograph_for_loop(self, extra_test, body, init_state):
"""Overload of for..in statement that iterates over the input."""
if extra_test is not None:
raise NotImplementedError(
"break and return statements are not yet supported in "
"for ... in distributed input loops.")
def reduce_body(state, iterate):
new_state = body(iterate, *state)
return new_state
if init_state:
return self.reduce(init_state, reduce_body)
# TODO(anjalisridhar): This is a workaround for Dataset.reduce not allowing
# empty state tensors - create a dummy state variable that remains unused.
# Identify if we need this workaround and remove if unnecessary.
def reduce_body_with_dummy_state(state, iterate):
reduce_body((), iterate)
return state
self.reduce((constant_op.constant(0),), reduce_body_with_dummy_state)
return ()
def reduce(self, initial_state, reduce_fn):
"""Execute a `reduce_fn` over all the elements of the input."""
iterator = iter(self)
has_data, data = _get_next_as_optional(iterator, self._strategy)
def cond(has_data, data, state):
del data, state # Unused.
return has_data
def loop_body(has_data, data, state):
"""Executes `reduce_fn` in a loop till the dataset is empty."""
del has_data # Unused.
# data is list of lists here. where each list corresponds to one worker.
# TODO(b/130570614): Add support for the multiworker and TPU pods use
# case.
if self._input_workers.num_workers == 1:
data = data[0]
else:
raise ValueError("Dataset iteration within a tf.function is"
" not supported for multiple workers.")
per_replica_data = values.regroup(self._input_workers.device_map, data)
state = reduce_fn(state, per_replica_data)
has_data, data = _get_next_as_optional(iterator, self._strategy)
return has_data, data, state
has_data, data, final_state = control_flow_ops.while_loop(
cond, loop_body, [has_data, data, initial_state], parallel_iterations=1)
return final_state
class DistributedDataset(_IterableInput):
"""Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices."""
def __init__(self,
dataset,
input_workers,
strategy,
split_batch_by=None,
input_context=None):
"""Distribute the dataset on all workers.
If `split_batch_by` is not None, we "split" each batch of the dataset by
`split_batch_by` value.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
`num_input_pipelines` in the `InputContext`.
"""
super(DistributedDataset, self).__init__(input_workers=input_workers)
# We clone and shard the dataset on each worker. The current setup tries to
# shard the dataset by files if possible so that each worker sees a
# different subset of files. If that is not possible, will attempt to shard
# the final input such that each worker will run the entire preprocessing
# pipeline and only receive its own shard of the dataset.
if split_batch_by:
try:
# pylint: disable=protected-access
with ops.colocate_with(dataset._variant_tensor):
dataset = distribute._RebatchDataset(dataset, split_batch_by)
# Add a prefetch to pipeline rebatching for performance.
# TODO(rachelim): Instead of inserting an extra prefetch stage here,
# leverage static graph rewrites to insert _RebatchDataset before
# the final `prefetch` if it exists.
dataset = dataset.prefetch(split_batch_by)
except errors.InvalidArgumentError as e:
if "without encountering a batch" in str(e):
six.reraise(
ValueError,
ValueError(
"Call the `batch` method on the input Dataset in order to be "
"able to split your input across {} replicas.\n Please "
"the tf.distribute.Strategy guide. {}".format(
split_batch_by, e)),
sys.exc_info()[2])
else:
raise
# TODO(b/138745411): Remove once stateful transformations are supported.
options = dataset_ops.Options()
options.experimental_distribute._make_stateless = True # pylint: disable=protected-access
dataset = dataset.with_options(options)
self._cloned_datasets = []
if input_context:
# Between-graph where we rely on the input_context for sharding
assert input_workers.num_workers == 1
dataset = input_ops.auto_shard_dataset(dataset,
input_context.num_input_pipelines,
input_context.input_pipeline_id)
self._cloned_datasets.append(dataset)
else:
replicated_ds = distribute.replicate(dataset,
input_workers.worker_devices)
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
cloned_dataset = replicated_ds[worker]
cloned_dataset = cloned_dataset.with_options(dataset.options())
cloned_dataset = input_ops.auto_shard_dataset(
cloned_dataset, len(input_workers.worker_devices), i)
self._cloned_datasets.append(cloned_dataset)
self._input_workers = input_workers
# TODO(anjalisridhar): Identify if we need to set this property on the
# iterator.
self.element_spec = dataset.element_spec
self._strategy = strategy
def __iter__(self):
if not (context.executing_eagerly() or
ops.get_default_graph().building_function):
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
self._input_workers)
iterator = DistributedIterator(self._input_workers, worker_iterators,
self._strategy)
iterator.element_spec = self.element_spec # pylint: disable=protected-access
return iterator
class DistributedDatasetV1(DistributedDataset):
"""Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""
def __init__(self,
dataset,
input_workers,
strategy,
split_batch_by=None,
input_context=None):
self._input_workers = input_workers
super(DistributedDatasetV1, self).__init__(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
input_context=input_context)
def make_one_shot_iterator(self):
"""Get a one time use iterator for DistributedDatasetV1.
Note: This API is deprecated. Please use `for ... in dataset:` to iterate
over the dataset or `iter` to create an iterator.
Returns:
A DistributedIteratorV1 instance.
"""
return self._make_one_shot_iterator()
def _make_one_shot_iterator(self):
"""Get an iterator for DistributedDatasetV1."""
# Graph mode with one shot iterator is disabled because we have to call
# `initialize` on the iterator which is only required if we are using a
# tf.distribute strategy.
if not context.executing_eagerly():
raise ValueError("Cannot create a one shot iterator. Please use "
"`make_initializable_iterator()` instead.")
return self._get_iterator()
def make_initializable_iterator(self):
"""Get an initializable iterator for DistributedDatasetV1.
Note: This API is deprecated. Please use
`tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
initializable iterator.
Returns:
A DistributedIteratorV1 instance.
"""
return self._make_initializable_iterator()
def _make_initializable_iterator(self, shared_name=None): # pylint: disable=unused-argument
"""Get an initializable iterator for DistributedDatasetV1."""
# Eager mode generates already initialized iterators. Hence we cannot create
# an initializable iterator.
if context.executing_eagerly():
raise ValueError("Cannot create initializable iterator in Eager mode. "
"Please use `iter()` instead.")
return self._get_iterator()
def _get_iterator(self):
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
self._input_workers)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy)
iterator.element_spec = self.element_spec # pylint: disable=protected-access
return iterator
# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
"""Inputs created from dataset function."""
def __init__(self, dataset_fn, input_workers, input_contexts, strategy):
"""Makes an iterable from datasets created by the given function.
Args:
dataset_fn: A function that returns a `Dataset` given an `InputContext`.
input_workers: an `InputWorkers` object.
input_contexts: A list of `InputContext` instances to be passed to call(s)
to `dataset_fn`. Length and order should match worker order in
`worker_device_pairs`.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
"""
super(DistributedDatasetsFromFunction, self).__init__(
input_workers=input_workers)
if input_workers.num_workers != len(input_contexts):
raise ValueError(
"Number of input workers (%d) is not same as number of "
"input_contexts (%d)" %
(input_workers.num_workers, len(input_contexts)))
self._dataset_fn = dataset_fn
self._input_workers = input_workers
self._input_contexts = input_contexts
self._strategy = strategy
def __iter__(self):
if not (context.executing_eagerly() or
ops.get_default_graph().building_function):
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
iterators = _create_iterators_per_worker_with_input_context(
self._input_contexts, self._input_workers, self._dataset_fn)
return DistributedIterator(self._input_workers, iterators, self._strategy)
class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
"""Inputs created from dataset function."""
def _make_initializable_iterator(self, shared_name=None):
"""Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
del shared_name # Unused
# Eager mode generates already initialized iterators. Hence we cannot create
# an initializable iterator.
if context.executing_eagerly():
raise ValueError("Cannot create initializable iterator in Eager mode. "
"Please use `iter()` instead.")
return self._get_iterator()
def _make_one_shot_iterator(self):
"""Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
# Graph mode with one shot iterator is disabled because we have to call
# `initialize` on the iterator which is only required if we are using a
# tf.distribute strategy.
if not context.executing_eagerly():
raise ValueError("Cannot create a one shot iterator. Please use "
"`make_initializable_iterator()` instead.")
return self._get_iterator()
def _get_iterator(self):
iterators = _create_iterators_per_worker_with_input_context(
self._input_contexts, self._input_workers, self._dataset_fn)
return DistributedIteratorV1(self._input_workers, iterators, self._strategy)
# TODO(anjalisridhar): This class will be soon be removed in favor of newer
# APIs.
class InputFunctionIterator(DistributedIteratorV1):
"""Iterator created from input function."""
def __init__(self, input_fn, input_workers, input_contexts, strategy):
"""Make an iterator for input provided via an input function.
Currently implements PER_WORKER mode, in which the `input_fn` is called
once on each worker.
TODO(priyag): Add other replication modes.
Args:
input_fn: Input function that returns a `tf.data.Dataset` object.
input_workers: an `InputWorkers` object.
input_contexts: A list of `InputContext` instances to be passed to call(s)
to `input_fn`. Length and order should match worker order in
`worker_device_pairs`.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
"""
assert isinstance(input_workers, InputWorkers)
if input_workers.num_workers != len(input_contexts):
raise ValueError(
"Number of input workers (%d) is not same as number of "
"input_contexts (%d)" %
(input_workers.num_workers, len(input_contexts)))
iterators = []
for i, ctx in enumerate(input_contexts):
worker = input_workers.worker_devices[i]
with ops.device(worker):
result = input_fn(ctx)
devices = input_workers.compute_devices_for_worker(i)
if isinstance(result, dataset_ops.DatasetV2):
iterator = _SingleWorkerDatasetIterator(result, worker, devices)
elif callable(result):
iterator = _SingleWorkerCallableIterator(result, worker, devices)
else:
raise ValueError(
"input_fn must return a tf.data.Dataset or a callable.")
iterators.append(iterator)
super(InputFunctionIterator, self).__init__(input_workers, iterators,
strategy)
# TODO(anjalisridhar): This class will soon be removed and users should move
# to using DistributedIterator.
class DatasetIterator(DistributedIteratorV1):
"""Iterator created from input dataset."""
def __init__(self,
dataset,
input_workers,
strategy,
split_batch_by=None,
input_context=None):
"""Make an iterator for the dataset on given devices.
If `split_batch_by` is not None, we "split" each batch of the
dataset by `split_batch_by` value.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
`num_input_pipelines` in the `InputContext`.
"""
dist_dataset = DistributedDatasetV1(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
input_context=input_context)
worker_iterators = _create_iterators_per_worker(
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
super(DatasetIterator, self).__init__(
input_workers,
worker_iterators, # pylint: disable=protected-access
strategy)
self.element_spec = dist_dataset.element_spec # pylint: disable=protected-access
def _dummy_tensor_fn(value_structure):
"""A function to create dummy tensors from `value_structure`."""
def create_dummy_tensor(feature_shape, feature_type):
"""Create a dummy tensor with possible batch dimensions set to 0."""
# Ideally we should set the batch dimension to 0, however as in
# DistributionStrategy we don't know the batch dimension, we try to
# guess it as much as possible. If the feature has unknown dimensions, we
# will set them to 0. If the feature shape is already static, we guess the
# first dimension as batch dimension and set it to 0.
dims = []
for dim in feature_shape.dims:
if dim.value is None:
dims.append(tensor_shape.Dimension(0))
else:
dims.append(dim)
if feature_shape.is_fully_defined() and dims:
dims[0] = tensor_shape.Dimension(0)
# Create the dummy tensor.
dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
return dummy_tensor
result = []
# pylint: disable=protected-access
for feature_shape, feature_type in zip(
structure.get_flat_tensor_shapes(value_structure),
structure.get_flat_tensor_types(value_structure)):
result.append(create_dummy_tensor(feature_shape, feature_type))
return nest.pack_sequence_as(value_structure, result)
# pylint: enable=protected-access
class _SingleWorkerDatasetIterator(object):
"""Iterator for a single `tf.data.Dataset`."""
def __init__(self, dataset, worker, devices):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
`MultiDeviceIterator` is used to prefetch input to the devices on the
given worker.
Args:
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
"""
self._dataset = dataset
self._worker = worker
self._devices = devices
self._make_iterator()
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._devices)
def get_next(self, device, name=None):
"""Get next element for the given device."""
del name
with ops.device(self._worker):
return self._iterator.get_next(device)
def get_next_as_list_deprecated(self, name=None):
"""Get next element from the underlying iterator."""
del name
with ops.device(self._worker):
data_list = self._iterator.get_next()
return data_list
def get_next_as_list(self, name=None):
"""Get next element from underlying iterator.
If there is no data left, a list of dummy tensors with possible batch
dimensions set to 0 will be returned.
Args:
name: not used.
Returns:
A boolean tensor indicates whether there is any data in next element and
the real data as the next element or a list of dummy tensors if no data
left.
"""
del name
with ops.device(self._worker):
data_list = self._iterator.get_next_as_optional()
result = []
for i, data in enumerate(data_list):
# Place the condition op in the same device as the data so the data
# doesn't need to be sent back to the worker.
with ops.device(self._devices[i]):
# As MultiDeviceIterator will fetch data in order, so we only need to
# check if the first replica has value to see whether there is data
# left for this single worker.
if i == 0:
worker_has_value = data.has_value()
# pylint: disable=unnecessary-lambda
# pylint: disable=cell-var-from-loop
real_data = control_flow_ops.cond(
data.has_value(),
lambda: data.get_value(),
lambda: _dummy_tensor_fn(data.value_structure))
result.append(real_data)
# pylint: enable=cell-var-from-loop
# pylint: enable=unnecessary-lambda
return worker_has_value, result
def initialize(self):
"""Initialze underlying iterator.
In eager execution, this simply recreates the underlying iterator.
In graph execution, it returns the initializer ops for the underlying
iterator.
Returns:
A list of any initializer ops that should be run.
"""
if ops.executing_eagerly_outside_functions():
self._iterator._eager_reset() # pylint: disable=protected-access
return []
else:
return [self._iterator.initializer]
@property
def output_classes(self):
return dataset_ops.get_legacy_output_classes(self._iterator)
@property
def output_shapes(self):
return dataset_ops.get_legacy_output_shapes(self._iterator)
@property
def output_types(self):
return dataset_ops.get_legacy_output_types(self._iterator)
class _SingleWorkerCallableIterator(object):
"""Iterator for a single tensor-returning callable."""
def __init__(self, fn, worker, devices):
self._fn = fn
self._worker = worker
self._devices = devices
def get_next(self, device, name=None):
"""Get next element for the given device from the callable."""
del device, name
with ops.device(self._worker):
return self._fn()
def get_next_as_list_deprecated(self, name=None):
"""Get next element from the callable."""
del name
with ops.device(self._worker):
data_list = [self._fn() for _ in self._devices]
return data_list
def get_next_as_list(self, name=None):
"""Get next element from the callable."""
del name
with ops.device(self._worker):
data_list = [self._fn() for _ in self._devices]
return constant_op.constant(True), data_list
def initialize(self):
# TODO(petebu) Should this throw an exception instead?
return []
def _create_iterators_per_worker(worker_datasets, input_workers):
"""Create a multidevice iterator on each of the workers."""
assert isinstance(input_workers, InputWorkers)
assert len(worker_datasets) == len(input_workers.worker_devices)
iterators = []
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
worker_devices)
iterators.append(iterator)
return iterators
def _create_iterators_per_worker_with_input_context(input_contexts,
input_workers,
dataset_fn):
"""Create a multidevice iterator per workers given a dataset function."""
iterators = []
for i, ctx in enumerate(input_contexts):
worker = input_workers.worker_devices[i]
with ops.device(worker):
dataset = dataset_fn(ctx)
# TODO(b/138745411): Remove once stateful transformations are supported.
options = dataset_ops.Options()
options.experimental_distribute._make_stateless = True # pylint: disable=protected-access
dataset = dataset.with_options(options)
devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
iterators.append(iterator)
return iterators
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
"""Get the batched dataset from `d`."""
# pylint: disable=protected-access
if isinstance(d, dataset_ops.DatasetV1Adapter):
d = d._dataset
if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
return d
elif isinstance(d, (dataset_ops.PrefetchDataset,
dataset_ops._OptionsDataset)):
return _get_batched_dataset(d._input_dataset)
raise ValueError(
"Unable to get batched dataset from the input dataset. `batch` "
"`map_and_batch` need to be the last operations on the dataset. "
"The batch operations can be followed by a prefetch.")
def _get_batched_dataset_attributes(d):
"""Get `batch_size`, `drop_remainder` of dataset."""
# pylint: disable=protected-access
assert isinstance(d,
(dataset_ops.BatchDataset, batching._MapAndBatchDataset))
if isinstance(d, dataset_ops.BatchDataset):
batch_size = d._batch_size
drop_remainder = d._drop_remainder
elif isinstance(d, batching._MapAndBatchDataset):
batch_size = d._batch_size_t
drop_remainder = d._drop_remainder_t
# pylint: enable=protected-access
if tensor_util.is_tensor(batch_size):
batch_size = tensor_util.constant_value(batch_size)
if tensor_util.is_tensor(drop_remainder):
drop_remainder = tensor_util.constant_value(drop_remainder)
return batch_size, drop_remainder
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
"""Get the underlying attributes from the dataset object."""
# pylint: disable=protected-access
# First, get batch_size and drop_remainder from the dataset. We need
# to walk back the dataset creation process and find the batched version in
# order to get the attributes.
batched_dataset = _get_batched_dataset(dataset)
batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)
# Second, prefetch buffer should be get from the original dataset.
prefetch_buffer = None
if isinstance(dataset, dataset_ops.PrefetchDataset):
prefetch_buffer = dataset._buffer_size
elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
prefetch_buffer = dataset._dataset._buffer_size
return batch_size, drop_remainder, prefetch_buffer
class MultiStepContext(object):
"""A context object that can be used to capture things when running steps.
This context object is useful when running multiple steps at a time using the
`experimental_run_steps_on_iterator` API. For e.g. it allows the user's step
function to specify which outputs to emit at what frequency. Currently it
supports capturing output from the last step, as well as capturing non tensor
outputs. In the future it will be augmented to support other use cases such
as output each N steps.
"""
def __init__(self):
"""Initialize an output context.
Returns:
A context object.
"""
self._last_step_outputs = {}
self._last_step_outputs_reduce_ops = {}
self._non_tensor_outputs = {}
@property
def last_step_outputs(self):
"""A dictionary consisting of outputs to be captured on last step.
Keys in the dictionary are names of tensors to be captured, as specified
when `set_last_step_output` is called.
Values in the dictionary are the tensors themselves. If
`set_last_step_output` was called with a `reduce_op` for this output,
then the value is the reduced value.
Returns:
A dictionary with last step outputs.
"""
return self._last_step_outputs
def _set_last_step_outputs(self, outputs):
"""Replace the entire dictionary of last step outputs."""
if not isinstance(outputs, dict):
raise ValueError("Need a dictionary to set last_step_outputs.")
self._last_step_outputs = outputs
def set_last_step_output(self, name, output, reduce_op=None):
"""Set `output` with `name` to be outputted from the last step.
Args:
name: String, name to identify the output. Doesn't need to match tensor
name.
output: The tensors that should be outputted with `name`. See below for
actual types supported.
reduce_op: Reduction method to use to reduce outputs from multiple
replicas. Required if `set_last_step_output` is called in a replica
context. Optional in cross_replica_context.
When present, the outputs from all the replicas are reduced using the
current distribution strategy's `reduce` method. Hence, the type of
`output` must be what's supported by the corresponding `reduce` method.
For e.g. if using MirroredStrategy and reduction is set, output
must be a `PerReplica` value.
The reduce method is also recorded in a dictionary
`_last_step_outputs_reduce_ops` for later interpreting of the
outputs as already reduced or not.
"""
if distribution_strategy_context.in_cross_replica_context():
self._last_step_outputs_reduce_ops[name] = reduce_op
if reduce_op is None:
self._last_step_outputs[name] = output
else:
distribution = distribution_strategy_context.get_strategy()
self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
axis=None)
else:
assert reduce_op is not None
def merge_fn(distribution, value):
self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
axis=None)
# Setting this inside the `merge_fn` because all replicas share the same
# context object, so it's more robust to set it only once (even if all
# the replicas are trying to set the same value).
self._last_step_outputs_reduce_ops[name] = reduce_op
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))
@property
def non_tensor_outputs(self):
"""A dictionary consisting of any non tensor outputs to be captured."""
return self._non_tensor_outputs
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
if distribution_strategy_context.in_cross_replica_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as reduction doesn't make sense on non tensors.
self._non_tensor_outputs[name] = (
distribution.experimental_local_results(value))
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))