# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating dataset servers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
from tensorflow.core.protobuf.data.experimental import service_config_pb2
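# pywrap_tensorflow is likely imported only for its side effect of loading the
# TensorFlow native runtime before _pywrap_server_lib is loaded (an
# assumption, consistent with the unused-import pylint disable above).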
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
collections.namedtuple(
"DispatcherConfig",
["port", "protocol", "work_dir", "fault_tolerant_mode"])):
"""Configuration class for tf.data service dispatchers.
Fields:
port: Specifies the port to bind to. A value of 0 indicates that the server
may bind to any available port.
protocol: The protocol to use for communicating with the tf.data service.
      Acceptable values include `"grpc"` and `"grpc+local"`.
work_dir: A directory to store dispatcher state in. This
argument is required for the dispatcher to be able to recover from
restarts.
fault_tolerant_mode: Whether the dispatcher should write its state to a
journal so that it can recover from restarts. Dispatcher state, including
registered datasets and created jobs, is synchronously written to the
journal before responding to RPCs. If `True`, `work_dir` must also be
specified.
"""
def __new__(cls,
port=0,
protocol="grpc",
work_dir=None,
fault_tolerant_mode=False):
return super(DispatcherConfig, cls).__new__(cls, port, protocol, work_dir,
fault_tolerant_mode)
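# A minimal usage sketch: the defaults give a gRPC dispatcher on an arbitrary
# free port without journaling; a fault-tolerant deployment pins the port and
# supplies a work_dir, e.g. (the path below is hypothetical):
#   DispatcherConfig(port=5050,
#                    work_dir="/tmp/dispatcher",
#                    fault_tolerant_mode=True)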
@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer(object):
"""An in-process tf.data service dispatch server.
A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
register themselves with the dispatcher.
>>> dispatcher = tf.data.experimental.service.DispatchServer()
>>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
When starting a dedicated tf.data dispatch process, use join() to block
indefinitely after starting up the server.
```
dispatcher = tf.data.experimental.service.DispatchServer(
tf.data.experimental.service.DispatcherConfig(port=5050))
dispatcher.join()
```
To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and
`fault_tolerant_mode` like below:
```
dispatcher = tf.data.experimental.service.DispatchServer(
tf.data.experimental.service.DispatcherConfig(
port=5050,
work_dir="gs://my-bucket/dispatcher/work_dir",
fault_tolerant_mode=True))
```
"""
def __init__(self, config=None, start=True):
"""Creates a new dispatch server.
Args:
config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
        configuration. If `None`, the dispatcher will use default
configuration values.
start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to True.
"""
config = config or DispatcherConfig()
if config.fault_tolerant_mode and not config.work_dir:
raise ValueError(
"Cannot enable fault tolerant mode without configuring a work_dir")
self._config = config
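    # The Python-level config is serialized to its proto form below; the C++
    # dispatcher behind _pywrap_server_lib consumes the serialized bytes.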
config_proto = service_config_pb2.DispatcherConfig(
port=config.port,
protocol=config.protocol,
work_dir=config.work_dir,
fault_tolerant_mode=config.fault_tolerant_mode)
self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
config_proto.SerializeToString())
if start:
self._server.start()
def start(self):
"""Starts this server.
>>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
>>> dispatcher.start()
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
starting the server.
"""
self._server.start()
def join(self):
"""Blocks until the server has shut down.
This is useful when starting a dedicated dispatch process.
```
dispatcher = tf.data.experimental.service.DispatchServer(
tf.data.experimental.service.DispatcherConfig(port=5050))
dispatcher.join()
```
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
joining the server.
"""
self._server.join()
@property
def target(self):
"""Returns a target that can be used to connect to the server.
>>> dispatcher = tf.data.experimental.service.DispatchServer()
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=dispatcher.target))
The returned string will be in the form protocol://address, e.g.
"grpc://localhost:5050".
"""
return "{0}://localhost:{1}".format(self._config.protocol,
self._server.bound_port())
def _stop(self):
"""Stops the server.
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
stopping the server.
"""
self._server.stop()
def __del__(self):
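    # Stop the native server when the Python wrapper is garbage-collected.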
self._stop()
@property
def _address(self):
"""Returns the address of the server.
The returned string will be in the form address:port, e.g. "localhost:1000".
"""
return "localhost:{0}".format(self._server.bound_port())
def _num_workers(self):
"""Returns the number of workers registered with the dispatcher."""
return self._server.num_workers()
@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
collections.namedtuple(
"WorkerConfig",
["dispatcher_address", "worker_address", "port", "protocol"])):
"""Configuration class for tf.data service workers.
Fields:
dispatcher_address: Specifies the address of the dispatcher.
worker_address: Specifies the address of the worker server. This address is
passed to the dispatcher so that the dispatcher can tell clients how to
connect to this worker.
port: Specifies the port to bind to. A value of 0 indicates that the worker
can bind to any available port.
protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc"` and `"grpc+local"`.
"""
def __new__(cls,
dispatcher_address,
worker_address=None,
port=0,
protocol="grpc"):
worker_address = ("localhost:%port%"
if worker_address is None else worker_address)
return super(WorkerConfig, cls).__new__(cls, dispatcher_address,
worker_address, port, protocol)
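# A minimal usage sketch: a worker that binds any free port and advertises
# "localhost:<bound port>" to a dispatcher at a hypothetical address:
#   WorkerConfig(dispatcher_address="localhost:5050")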
@tf_export("data.experimental.service.WorkerServer", v1=[])
class WorkerServer(object):
"""An in-process tf.data service worker server.
A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
processing for user-defined datasets, and provides the resulting elements over
RPC. A worker is associated with a single
`tf.data.experimental.service.DispatchServer`.
>>> dispatcher = tf.data.experimental.service.DispatchServer()
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... tf.data.experimental.service.WorkerConfig(
... dispatcher_address=dispatcher_address))
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
When starting a dedicated tf.data worker process, use join() to block
indefinitely after starting up the server.
```
  worker = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address="localhost:5050", port=5051))
worker.join()
```
"""
def __init__(self, config, start=True):
"""Creates a new worker server.
Args:
      config: A `tf.data.experimental.service.WorkerConfig` configuration.
start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to True.
"""
if config.dispatcher_address is None:
raise ValueError("must specify a dispatcher_address")
config_proto = service_config_pb2.WorkerConfig(
dispatcher_address=config.dispatcher_address,
worker_address=config.worker_address,
port=config.port,
protocol=config.protocol)
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
config_proto.SerializeToString())
if start:
self._server.start()
def start(self):
"""Starts this server.
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
starting the server.
"""
self._server.start()
def join(self):
"""Blocks until the server has shut down.
This is useful when starting a dedicated worker process.
```
  worker_server = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address="localhost:5050", port=5051))
worker_server.join()
```
This method currently blocks forever.
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
joining the server.
"""
self._server.join()
def _stop(self):
"""Stops the server.
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
stopping the server.
"""
self._server.stop()
def __del__(self):
self._stop()
@property
def _address(self):
"""Returns the address of the server.
The returned string will be in the form address:port, e.g. "localhost:1000".
"""
return "localhost:{0}".format(self._server.bound_port())