# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
class _PerDeviceGenerator(dataset_ops.DatasetV2):
"""A `dummy` generator dataset."""
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
source_device, element_spec):
self._element_spec = element_spec
multi_device_iterator_string_handle = (
gen_dataset_ops.multi_device_iterator_to_string_handle(
multi_device_iterator_resource))
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
def _init_func():
return multi_device_iterator_string_handle
init_func_concrete = _init_func.get_concrete_function()
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(autograph=False) # Pure graph code.
def _remote_init_func():
return functional_ops.remote_call(
target=source_device,
args=init_func_concrete.captured_inputs,
Tout=[dtypes.string],
f=init_func_concrete)
self._init_func = _remote_init_func.get_concrete_function()
self._init_captured_args = self._init_func.captured_inputs
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _next_func(string_handle):
# pylint: disable=protected-access
multi_device_iterator = (
gen_dataset_ops.multi_device_iterator_from_string_handle(
string_handle=string_handle,
output_types=structure.get_flat_tensor_types(self._element_spec),
output_shapes=structure.get_flat_tensor_shapes(
self._element_spec)))
return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
multi_device_iterator=multi_device_iterator,
shard_num=shard_num,
incarnation_id=incarnation_id,
output_types=structure.get_flat_tensor_types(self._element_spec),
output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
next_func_concrete = _next_func.get_concrete_function()
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun_with_attributes(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
attributes={"experimental_ints_on_device": True},
autograph=False) # Pure graph code.
def _remote_next_func(string_handle):
return functional_ops.remote_call(
target=source_device,
args=[string_handle] + next_func_concrete.captured_inputs,
Tout=structure.get_flat_tensor_types(self._element_spec),
f=next_func_concrete)
self._next_func = _remote_next_func.get_concrete_function()
self._next_captured_args = self._next_func.captured_inputs
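    # Remember where incarnation_id sits among the captured arguments so that
    # _ReincarnatedPerDeviceGenerator can later splice in a fresh incarnation
    # id without re-tracing the function.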
self._incarnation_id_index = -1
for i, arg in enumerate(self._next_captured_args):
if arg is incarnation_id:
self._incarnation_id_index = i
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _finalize_func(unused_string_handle):
return array_ops.constant(0, dtypes.int64)
finalize_func_concrete = _finalize_func.get_concrete_function()
# TODO(b/124254153): Enable autograph once the overhead is low enough.
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
autograph=False) # Pure graph code.
def _remote_finalize_func(string_handle):
return functional_ops.remote_call(
target=source_device,
args=[string_handle] + finalize_func_concrete.captured_inputs,
Tout=[dtypes.int64],
f=finalize_func_concrete)
self._finalize_func = _remote_finalize_func.get_concrete_function()
self._finalize_captured_args = self._finalize_func.captured_inputs
variant_tensor = gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
init_func=self._init_func,
next_func=self._next_func,
finalize_func=self._finalize_func,
**self._flat_structure)
super(_PerDeviceGenerator, self).__init__(variant_tensor)
def _inputs(self):
# TODO(b/116506223): Determine which datasets should be used as inputs here.
return []
@property
def element_spec(self):
return self._element_spec
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
"""Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.
Re-uses the functions from the provided per_device_dataset and just switches
out the function argument corresponding to the incarnation_id.
"""
def __init__(self, per_device_dataset, incarnation_id):
# pylint: disable=protected-access
self._element_spec = per_device_dataset.element_spec
self._init_func = per_device_dataset._init_func
self._init_captured_args = self._init_func.captured_inputs
self._next_func = per_device_dataset._next_func
self._next_captured_args = per_device_dataset._next_captured_args
# The captured arguments to the next_func are string_handle, incarnation_id.
# We update the incarnation id to the new one.
self._next_captured_args[
per_device_dataset._incarnation_id_index] = incarnation_id
self._finalize_func = per_device_dataset._finalize_func
self._finalize_captured_args = per_device_dataset._finalize_captured_args
variant_tensor = gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
init_func=self._init_func,
next_func=self._next_func,
finalize_func=self._finalize_func,
**self._flat_structure)
super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)
def _inputs(self):
# TODO(b/116506223): Determine which datasets should be used as inputs here.
return []
@property
def element_spec(self):
return self._element_spec
def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size,
experimental_slack):
"""Uses _prototype_device_datasets[i] to build a dataset for the device."""
ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id)
if prefetch_buffer_size > 0:
if experimental_slack:
ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size, slack_period=1)
else:
ds = ds.prefetch(prefetch_buffer_size)
# TODO(jsimsa): Enable auto-tuning and optimizations when supported for
# non-CPU devices.
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.autotune = False
ds = ds.with_options(options)
return ds
class MultiDeviceIterator(object):
"""An iterator over multiple devices."""
def __init__(self,
dataset,
devices,
max_buffer_size=1,
prefetch_buffer_size=1,
source_device="/cpu:0"):
"""Constructs a MultiDeviceIterator.

    Args:
dataset: The input dataset to be iterated over.
devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host-side per-device buffer to keep.
        To prevent deadlocks, if `prefetch_buffer_size` is greater than
        `max_buffer_size`, then `max_buffer_size` is raised to
        `prefetch_buffer_size`.
      prefetch_buffer_size: If greater than 0, a buffer of this size is set up
        on each device to prefetch into.
      source_device: The host device to place the `dataset` on.
    """
options = dataset_ops.Options()
options.experimental_distribute.num_devices = len(devices)
dataset = dataset.with_options(options)
self._dataset = dataset._apply_options() # pylint: disable=protected-access
self._experimental_slack = dataset.options().experimental_slack
self._devices = devices
self._source_device = source_device
self._source_device_tensor = ops.convert_to_tensor(source_device)
self._max_buffer_size = max_buffer_size
self._prefetch_buffer_size = prefetch_buffer_size
if self._prefetch_buffer_size > self._max_buffer_size:
self._max_buffer_size = self._prefetch_buffer_size
# Create the MultiDeviceIterator.
with ops.device(self._source_device):
# TODO(b/121378567): Get rid of this shared_name hack.
shared_name = ""
if context.executing_eagerly():
shared_name = context.shared_name()
self._multi_device_iterator_resource = (
gen_dataset_ops.multi_device_iterator(
devices=self._devices,
shared_name=shared_name,
container="",
**self._dataset._flat_structure)) # pylint: disable=protected-access
if context.executing_eagerly():
        # Delete the resource when this object is deleted.
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._multi_device_iterator_resource,
handle_device=self._source_device)
# The incarnation ID is used to ensure consistency between the per-device
# iterators and the multi-device iterator.
self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
self._dataset._variant_tensor, # pylint: disable=protected-access
self._multi_device_iterator_resource,
max_buffer_size=self._max_buffer_size)
self._prototype_device_datasets = []
for i, device in enumerate(self._devices):
with ops.device(device):
ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
self._incarnation_id,
self._source_device_tensor,
self._dataset.element_spec)
self._prototype_device_datasets.append(ds)
    # TODO(rohanj): Explore the possibility of having the MultiDeviceIterator
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # from its input into the device side. It might be useful in rewriting.
# Create the per device iterators.
self._device_iterators = []
for i, device in enumerate(self._devices):
with ops.device(device):
ds = _create_device_dataset(self._prototype_device_datasets[i],
self._incarnation_id,
self._prefetch_buffer_size,
self._experimental_slack)
if context.executing_eagerly():
self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
else:
self._device_iterators.append(
dataset_ops.make_initializable_iterator(ds))
if not context.executing_eagerly():
device_iterator_initializers = [
iterator.initializer for iterator in self._device_iterators
]
self._initializer = control_flow_ops.group(*device_iterator_initializers)
  def _create_device_dataset(self, i):
    """Uses _prototype_device_datasets[i] to build a dataset for the device."""
    return _create_device_dataset(self._prototype_device_datasets[i],
                                  self._incarnation_id,
                                  self._prefetch_buffer_size,
                                  self._experimental_slack)
def get_next(self, device=None):
"""Returns the next element given a `device`, else returns all in a list."""
if device is not None:
index = self._devices.index(device)
return self._device_iterators[index].get_next()
result = []
for i, device in enumerate(self._devices):
with ops.device(device):
result.append(self._device_iterators[i].get_next())
return result
def get_next_as_optional(self):
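    """Returns an `Optional` per device, wrapping that device's next element.

    Unlike `get_next`, this does not raise `errors.OutOfRangeError` at the end
    of the sequence; instead, the returned `Optional`s contain no value.
    """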
result = []
for i, device in enumerate(self._devices):
with ops.device(device):
result.append(
iterator_ops.get_next_as_optional(self._device_iterators[i]))
return result
@property
def initializer(self):
if context.executing_eagerly():
return control_flow_ops.no_op()
return self._initializer
def _eager_reset(self):
"""Resets the MultiDeviceIterator in eager mode."""
if not ops.executing_eagerly_outside_functions():
raise ValueError("Eager reset is only supported in eager mode.")
# pylint: disable=protected-access
self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
self._dataset._variant_tensor,
self._multi_device_iterator_resource,
max_buffer_size=self._max_buffer_size)
for i, device in enumerate(self._devices):
with ops.device(device):
ds = _create_device_dataset(self._prototype_device_datasets[i],
self._incarnation_id,
self._prefetch_buffer_size,
self._experimental_slack)
# Reset the device iterator resources with the new dataset.
ds_variant = ds._variant_tensor
gen_dataset_ops.make_iterator(
ds_variant, self._device_iterators[i]._iterator_resource)
@property
def element_spec(self):
return self._dataset.element_spec
class MultiDeviceIteratorResourceDeleter(object):
"""An object which cleans up a Multi Device Iterator resource.
An alternative to defining a __del__ method on an object. Even if the parent
object is part of a reference cycle, the cycle will be collectible.
"""
def __init__(self, multi_device_iterator, iterators, device, deleter):
self._deleter = deleter
self._multi_device_iterator = multi_device_iterator
self._iterators = iterators
self._device = device
self._eager_mode = context.executing_eagerly()
def __del__(self):
with ops.device(self._device):
# Make sure the resource is deleted in the same mode as it was created in.
# We pass in the iterator handles as inputs to the op to make sure that
# this op runs after all the iterators are deleted.
if self._eager_mode:
with context.eager_mode():
gen_dataset_ops.delete_multi_device_iterator(
multi_device_iterator=self._multi_device_iterator,
iterators=self._iterators,
deleter=self._deleter)
else:
with context.graph_mode():
gen_dataset_ops.delete_multi_device_iterator(
multi_device_iterator=self._multi_device_iterator,
iterators=self._iterators,
deleter=self._deleter)
class MultiDeviceIteratorSpec(type_spec.TypeSpec):
"""Type specification for `OwnedMultiDeviceIterator`."""
__slots__ = ["_devices", "_source_device", "_element_spec"]
def __init__(self, devices, source_device, element_spec):
self._devices = devices
self._source_device = source_device
self._element_spec = element_spec
@property
def value_type(self):
return OwnedMultiDeviceIterator
def _serialize(self):
return (tuple(self._devices), self._source_device, self._element_spec)
@property
def _component_specs(self):
specs = [
tensor_spec.TensorSpec([], dtypes.resource),
tensor_spec.TensorSpec([], dtypes.variant)
]
for _ in range(len(self._devices)):
specs.append(iterator_ops.IteratorSpec(self._element_spec))
return specs
def _to_components(self, value):
# pylint: disable=protected-access
c = [value._multi_device_iterator_resource, value._deleter]
c.extend(value._device_iterators)
return c
def _from_components(self, components):
return OwnedMultiDeviceIterator(
dataset=None,
devices=self._devices,
source_device=self._source_device,
components=components,
element_spec=self._element_spec)
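  # Rebuilding the iterator from `_to_components` output is what allows
  # `OwnedMultiDeviceIterator` instances to be passed in and out of
  # `tf.function`s as composite tensors.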
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return MultiDeviceIteratorSpec(
value._devices,
value._source_device,
value.element_spec)
class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
"""An iterator over multiple devices.
The multi-device iterator resource created through `OwnedMultiDeviceIterator`
is owned by the Python object and the life time of the underlying resource is
tied to the life time of the `OwnedMultiDeviceIterator` object. This makes
`OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of
tf.functions.
"""
def __init__(self,
dataset=None,
devices=None,
max_buffer_size=1,
prefetch_buffer_size=1,
source_device="/cpu:0",
components=None,
element_spec=None):
"""Constructs an owned MultiDeviceIterator object.

    Args:
dataset: The input dataset to be iterated over.
devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host-side per-device buffer to keep.
        To prevent deadlocks, if `prefetch_buffer_size` is greater than
        `max_buffer_size`, then `max_buffer_size` is raised to
        `prefetch_buffer_size`.
      prefetch_buffer_size: If greater than 0, a buffer of this size is set up
        on each device to prefetch into.
      source_device: The host device to place the `dataset` on.
components: Tensor components to construct the MultiDeviceIterator from.
element_spec: A nested structure of `TypeSpec` objects that
represents the type specification of elements of the iterator.

    Raises:
RuntimeError: If executed in graph mode or outside of function building
mode.
"""
if not context.executing_eagerly() and not ops.inside_function():
raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
"tf.function or when eager execution is enabled.")
if devices is None:
raise ValueError("`devices` must be provided")
error_message = "Either `dataset` or both `components` and "
"`element_spec` need to be provided."
if dataset is None:
if (components is None or element_spec is None):
raise ValueError(error_message)
self._element_spec = element_spec
self._devices = devices
self._source_device = source_device
self._multi_device_iterator_resource = components[0]
self._deleter = components[1]
self._device_iterators = components[2:]
iterator_handles = []
for it in self._device_iterators:
iterator_handles.append(it._iterator_resource) # pylint: disable=protected-access
else:
if (components is not None or element_spec is not None):
raise ValueError(error_message)
options = dataset_ops.Options()
options.experimental_distribute.num_devices = len(devices)
dataset = dataset.with_options(options)
dataset = dataset._apply_options() # pylint: disable=protected-access
self._element_spec = dataset.element_spec
experimental_slack = dataset.options().experimental_slack
self._devices = devices
self._source_device = source_device
source_device_tensor = ops.convert_to_tensor(self._source_device)
if prefetch_buffer_size > max_buffer_size:
max_buffer_size = prefetch_buffer_size
# Create the MultiDeviceIterator.
with ops.device(self._source_device):
self._multi_device_iterator_resource, self._deleter = (
gen_dataset_ops.anonymous_multi_device_iterator(
devices=self._devices, **dataset._flat_structure)) # pylint: disable=protected-access
# The incarnation ID is used to ensure consistency between the
# per-device iterators and the multi-device iterator.
incarnation_id = gen_dataset_ops.multi_device_iterator_init(
dataset._variant_tensor, # pylint: disable=protected-access
self._multi_device_iterator_resource,
max_buffer_size=max_buffer_size)
prototype_device_datasets = []
for i, device in enumerate(self._devices):
with ops.device(device):
ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
incarnation_id, source_device_tensor,
dataset.element_spec)
prototype_device_datasets.append(ds)
      # TODO(rohanj): Explore the possibility of having the MultiDeviceIterator
      # initialize the device side of the pipeline. This would allow the
      # MultiDeviceIterator to choose, for example, to move some transformations
      # from its input into the device side. It might be useful in rewriting.
# Create the per device iterators.
self._device_iterators = []
iterator_handles = []
for i, device in enumerate(self._devices):
with ops.device(device):
ds = _create_device_dataset(prototype_device_datasets[i],
incarnation_id, prefetch_buffer_size,
experimental_slack)
iterator = iter(ds)
self._device_iterators.append(iterator)
iterator_handles.append(iterator._iterator_resource) # pylint: disable=protected-access
self._resource_deleter = MultiDeviceIteratorResourceDeleter(
multi_device_iterator=self._multi_device_iterator_resource,
iterators=iterator_handles,
device=self._source_device,
deleter=self._deleter)
def get_next(self, device=None):
"""Returns the next element given a `device`, else returns all in a list."""
if device is not None:
index = self._devices.index(device)
return self._device_iterators[index].get_next()
result = []
for i, device in enumerate(self._devices):
with ops.device(device):
result.append(self._device_iterators[i].get_next())
return result
def __iter__(self):
return self
def __next__(self):
return self.next()
def next(self):
try:
return self.get_next()
except errors.OutOfRangeError:
raise StopIteration
def get_next_as_optional(self):
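    """Returns an `Optional` per device, wrapping that device's next element.

    Unlike `get_next`, this does not raise `errors.OutOfRangeError` at the end
    of the sequence; instead, the returned `Optional`s contain no value.
    """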
result = []
for i, device in enumerate(self._devices):
with ops.device(device):
result.append(
iterator_ops.get_next_as_optional(self._device_iterators[i]))
return result
@property
def element_spec(self):
return self._element_spec
@property
def _type_spec(self):
return MultiDeviceIteratorSpec(self._devices, self._source_device,
self._element_spec)