# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for executing a tf.data.Dataset using a tf.data service."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import six

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export


class ProcessingMode(object):
  """tf.data service processing modes."""

  PARALLEL_EPOCHS = "parallel_epochs"

  @staticmethod
  def validate(mode):
    """Raises a ValueError if the given object is not a valid processing mode."""
    valid_modes = [ProcessingMode.PARALLEL_EPOCHS]
    if mode not in valid_modes:
      raise ValueError(
          "{0} is not a valid processing mode. Valid modes: {1}".format(
              mode, valid_modes))
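

# Example usage sketch (illustrative values): validating the supported mode
# succeeds, while any other mode raises a ValueError.
#
#   ProcessingMode.validate("parallel_epochs")  # OK.
#   ProcessingMode.validate("one_epoch")  # Raises ValueError.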


class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               input_dataset,
               dataset_id,
               processing_mode,
               address,
               protocol,
               job_name=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None):
    """Constructs a _DataServiceDatasetV2.

    Args:
      input_dataset: The input dataset, which should be registered with the
        tf.data service under `dataset_id`.
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A string specifying the policy for how data should be
        processed by tf.data workers. Currently, the only supported value is
        "parallel_epochs".
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      job_name: (Optional.) The name of the job. This argument makes it
        possible for multiple datasets to share the same job. The default
        behavior is that the dataset creates anonymous, exclusively owned jobs.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the master for task changes.
    """
    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
    self._input_dataset = input_dataset
    self._dataset_id = ops.convert_to_tensor(
        dataset_id, dtype=dtypes.int64, name="dataset_id")
    self._processing_mode = ops.convert_to_tensor(
        processing_mode, dtype=dtypes.string, name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    self._element_spec = input_dataset.element_spec
    variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
        dataset_id=self._dataset_id,
        processing_mode=self._processing_mode,
        address=self._address,
        protocol=self._protocol,
        job_name=self._job_name,
        max_outstanding_requests=self._max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
        ),
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, input_dataset, dataset_id, processing_mode, address,
               protocol, job_name, max_outstanding_requests,
               task_refresh_interval_hint_ms):
    self._wrapped = _DataServiceDatasetV2(
        input_dataset=input_dataset,
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)


if tf2.enabled():
  _DataServiceDataset = _DataServiceDatasetV2
else:
  _DataServiceDataset = _DataServiceDatasetV1


def _distribute(processing_mode,
                service,
                job_name=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Currently, the only supported value is
      "parallel_epochs".
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format <protocol>://<address>, e.g.
      grpc://localhost:5000.
    job_name: (Optional.) The name of the job. This argument makes it
      possible for multiple datasets to share the same job. The default behavior
      is that the dataset creates anonymous, exclusively owned jobs.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      master for task changes.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  ProcessingMode.validate(processing_mode)
  if job_name is not None:
    if not isinstance(job_name, six.string_types):
      raise ValueError("job_name must be a string, but job_name was of type "
                       "{0}. job_name={1}".format(type(job_name), job_name))
    if not job_name:
      raise ValueError("job_name must not be empty")
  if not isinstance(service, six.string_types):
    raise ValueError(
        "service must be a string, but service was of type {0}. service={1}"
        .format(type(service), service))
  if not service:
    raise ValueError("service must not be empty")
  parts = service.split("://")
  if len(parts) == 1:
    raise ValueError("service string %s does not begin with a protocol. "
                     "The service should be in the format "
                     "<protocol>://<address>, e.g. grpc://localhost:5000" %
                     service)
  if len(parts) > 2:
    raise ValueError("malformed service string has multiple '://': %s" %
                     service)
  protocol, address = parts
  address = ops.convert_to_tensor(address, dtype=dtypes.string, name="address")
  protocol = ops.convert_to_tensor(
      protocol, dtype=dtypes.string, name="protocol")

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    external_state_policy = dataset.options().experimental_external_state_policy
    if external_state_policy is None:
      external_state_policy = ExternalStatePolicy.WARN
    uncompressed_spec = dataset.element_spec
    # Compress the dataset elements to reduce the amount of data that needs to
    # be sent over the network.
    # TODO(b/157105111): Make this an autotuned parallel map when we have a way
    # to limit memory usage.
    dataset = dataset.map(lambda *x: compression_ops.compress(x))
    # Prefetch one compressed element to reduce latency when requesting data
    # from tf.data workers.
    # TODO(b/157105111): Set this to autotune when we have a way to limit
    # memory usage.
    dataset = dataset.prefetch(1)
    # Apply options so that the dataset executed in the tf.data service will
    # be optimized and support autotuning.
    dataset = dataset._apply_options()  # pylint: disable=protected-access
    dataset_id = gen_experimental_dataset_ops.register_dataset(
        dataset._variant_tensor,  # pylint: disable=protected-access
        address=address,
        protocol=protocol,
        external_state_policy=external_state_policy.value)
    dataset = _DataServiceDataset(
        input_dataset=dataset,
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
    # TODO(b/157105111): Make this an autotuned parallel map when we have a way
    # to limit memory usage.
    # The value 16 is chosen based on experience with pipelines that require
    # more than 8 parallel calls to prevent this stage from being a bottleneck.
    dataset = dataset.map(
        lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec),
        num_parallel_calls=16)
    # Disable autosharding for shared jobs.
    if job_name:
      options = dataset_ops.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
    return dataset

  return _apply_fn
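

# Example usage sketch (placeholder address and interval values): `_distribute`
# is applied like the public `distribute` transformation, but also exposes the
# `task_refresh_interval_hint_ms` knob.
#
#   dataset = dataset.apply(
#       _distribute("parallel_epochs", "grpc://localhost:5000",
#                   task_refresh_interval_hint_ms=1000))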


@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               max_outstanding_requests=None):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The `processing_mode` argument controls what data is produced by a tf.data
  service job. Currently, the only supported mode is "parallel_epochs".
  processing_mode="parallel_epochs" means that multiple tf.data workers will
  iterate through the dataset in parallel, each producing all elements of the
  dataset. For example, if the dataset contains {0, 1, 2}, every tf.data worker
  used for execution will produce {0, 1, 2}. If there are 3 workers, the job
  will produce the elements {0, 0, 0, 1, 1, 1, 2, 2, 2} (though not necessarily
  in that order). To account for this, it is recommended to randomly shuffle
  your dataset, so that different tf.data workers will iterate through the
  dataset in different orders.

  In the future, there will be additional processing modes. For example,
  a "one_epoch" mode which partitions the dataset across the tf.data
  workers, so that the consumers see each element of the dataset only once.

  ```
  dataset = tf.data.Dataset.range(5)
  dataset = dataset.map(lambda x: x*x)
  dataset = dataset.apply(
      tf.data.experimental.service.distribute("parallel_epochs",
                                              "grpc://dataservice:5000"))
  dataset = dataset.map(lambda x: x+1)
  for element in dataset:
    print(element)  # prints { 1, 2, 5, 10, 17 }
  ```

  In the above example, the first two lines (before the call to `distribute`)
  will be executed on tf.data workers, and the elements will be provided over
  RPC. The remaining transformations (after the call to `distribute`) will be
  executed locally.

  The `job_name` argument allows jobs to be shared across multiple
  datasets. Instead of each dataset creating its own job, all
  datasets with the same `job_name` will consume from the same job. A new job
  will be created for each iteration of the dataset (with each repetition of
  `Dataset.repeat` counting as a new iteration). Suppose two training workers
  (in either a single client or multi-client setup) iterate over the below
  dataset, and there is a single tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job.

  To use the tf.data service with multiple training jobs, make sure to use
  different job names to avoid conflicts. For example, suppose a training job
  calls `distribute` with `job_name="job"` and reads until end of input. If
  another independent job connects to the same tf.data service and tries to
  read from `job_name="job"`, it will immediately receive end of input,
  without getting any data.
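
  For example, two independent training jobs sharing one tf.data service might
  each pass their own name (the names below are only illustrative):

  ```
  # Training job 1:
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://dataservice:5000", job_name="training_job_1"))

  # An independent training job 2 uses a different name, so that it reads its
  # own data instead of the leftovers of training job 1:
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://dataservice:5000", job_name="training_job_2"))
  ```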

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a `job_name`,
  since sharing the job already results in splitting data across the workers.
  When using a shared job, data will be dynamically balanced across workers, so
  that they reach end of input about the same time. This results in better
  worker utilization than with autosharding, where each worker processes an
  independent set of files, and some workers may run out of data earlier than
  others.
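
  For example, a sketch of the recommended usage with a distribution strategy
  (the strategy choice, service address, and job name below are placeholders):

  ```
  strategy = tf.distribute.MirroredStrategy()
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://dataservice:5000", job_name="shared_job"))
  # Autosharding is disabled because a `job_name` is set; the workers instead
  # dynamically share the elements of the "shared_job" job.
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  ```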

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Currently, the only supported value is
      "parallel_epochs".
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format protocol://address, e.g.
      grpc://localhost:5000.
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      max_outstanding_requests=max_outstanding_requests)