# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""A RunConfig subclass with TPU support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import tf_logging as logging
# pylint: disable=protected-access
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
# pylint: enable=protected-access
class InputPipelineConfig(object):
r"""Please see the definition of these values in TPUConfig."""
PER_SHARD_V1 = 1
PER_HOST_V1 = 2
PER_HOST_V2 = 3
BROADCAST = 4
class TPUConfig(
collections.namedtuple('TPUConfig', [
'iterations_per_loop',
'num_shards',
'num_cores_per_replica',
'per_host_input_for_training',
'tpu_job_name',
'initial_infeed_sleep_secs',
'input_partition_dims',
])):
r"""TPU related configuration required by `TPUEstimator`.
Args:
    iterations_per_loop: The number of train steps running in the TPU system
      before returning to the CPU host for each `Session.run`. This means the
      global step is increased `iterations_per_loop` times in one
      `Session.run`. It is recommended to set this to the number of global
      steps between checkpoints.
    num_shards: (Deprecated, ignored by TPUEstimator).
      The number of model replicas in the system. In the non-model-parallelism
      case, this number equals the total number of TPU cores. With model
      parallelism, the total number of TPU cores equals
      `num_cores_per_replica` * `num_shards`.
    num_cores_per_replica: Defaults to `None`, which disables model
      parallelism. An integer which describes the number of TPU cores per
      model replica. This is required by model parallelism, which partitions
      the model across multiple cores. Currently `num_cores_per_replica` must
      be 1, 2, 4, 8, or 16.
    per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
      `input_fn` is invoked once on each host; with the per-core input
      pipeline configuration (`PER_SHARD_V1`), it is invoked once for each
      core.
      Given a global batch size `train_batch_size` in the `TPUEstimator`
      constructor, the per-shard batch size is `train_batch_size` // #hosts in
      the `True` or `PER_HOST_V1` mode, and `train_batch_size` // #cores in
      `PER_HOST_V2` mode. In `BROADCAST` mode, `input_fn` is invoked only once
      on host 0 and the tensors are broadcast to all other replicas; the batch
      size equals `train_batch_size`. With the per-core input pipeline
      configuration, the per-shard batch size is also
      `train_batch_size` // #cores.
      Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
    tpu_job_name: The name of the TPU job. Typically, this name is
      auto-inferred within TPUEstimator; however, when using ClusterSpec
      propagation in more esoteric cluster configurations, you may need to
      specify the job name as a string.
initial_infeed_sleep_secs: The number of seconds the infeed thread should
wait before enqueueing the first batch. This helps avoid timeouts for
models that require a long compilation time.
    input_partition_dims: A nested list describing the partition dims for all
      the tensors from input_fn(). The structure of input_partition_dims must
      match the structure of `features` and `labels` from input_fn(). The
      total number of partitions must match `num_cores_per_replica`. For
      example, if input_fn() returns two tensors, images with shape
      [N, H, W, C] and labels with shape [N], then
      input_partition_dims = [[1, 2, 2, 1], None] will split the images into
      4 pieces and feed them into 4 TPU cores; the labels tensor is broadcast
      directly to all the TPU cores since its partition dims is `None` (see
      the example below).
      Current limitation: this feature is only supported with the PER_HOST_V2
      input mode.
Raises:
    ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8, or 16.
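
  Example:
    A minimal, illustrative sketch of a `TPUConfig` that uses spatial
    partitioning; the particular values below are assumptions chosen for the
    example, not recommended settings:

      config = TPUConfig(
          iterations_per_loop=100,
          num_cores_per_replica=4,
          per_host_input_for_training=InputPipelineConfig.PER_HOST_V2,
          # Split [N, H, W, C] images 2x2 along H and W (4 partitions in
          # total, matching num_cores_per_replica); broadcast the labels.
          input_partition_dims=[[1, 2, 2, 1], None])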
"""
def __new__(cls,
iterations_per_loop=2,
num_shards=None,
num_cores_per_replica=None,
per_host_input_for_training=True,
tpu_job_name=None,
initial_infeed_sleep_secs=None,
input_partition_dims=None):
# Check iterations_per_loop.
util_lib.check_positive_integer(iterations_per_loop,
'TPUConfig iterations_per_loop')
# Check num_shards.
if num_shards is not None:
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
if input_partition_dims is not None:
if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
raise ValueError(
'input_partition_dims must be a list/tuple with one or two'
' elements.')
if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
raise ValueError(
'input_partition_dims is only supported in PER_HOST_V2 mode.')
if num_cores_per_replica is None:
raise ValueError(
'input_partition_dims requires setting num_cores_per_replica.')
# Check num_cores_per_replica
if num_cores_per_replica is not None:
if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
    # per_host_input_for_training may be True, False, or an integer in [1..4].
    # Map legacy values (True, False) to numeric values.
if per_host_input_for_training is False:
per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
elif per_host_input_for_training is True:
per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
# Check initial_infeed_sleep_secs.
if initial_infeed_sleep_secs:
util_lib.check_positive_integer(initial_infeed_sleep_secs,
'TPUConfig initial_infeed_sleep_secs')
tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
return super(TPUConfig, cls).__new__(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
num_cores_per_replica=num_cores_per_replica,
per_host_input_for_training=per_host_input_for_training,
tpu_job_name=tpu_job_name,
initial_infeed_sleep_secs=initial_infeed_sleep_secs,
input_partition_dims=input_partition_dims)
class RunConfig(run_config_lib.RunConfig):
"""RunConfig with TPU support."""
def __init__(self,
tpu_config=None,
evaluation_master=None,
master=None,
cluster=None,
**kwargs):
"""Constructs a RunConfig.
Args:
tpu_config: the TPUConfig that specifies TPU-specific configuration.
evaluation_master: a string. The address of the master to use for eval.
Defaults to master if not set.
master: a string. The address of the master to use for training.
      cluster: a `ClusterResolver` instance. If set, it is used to derive the
        master address and the cluster spec; it cannot be combined with an
        explicit `master`.
**kwargs: keyword config parameters.
Raises:
ValueError: if cluster is not None and the provided session_config has a
cluster_def already.
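
    Example:
      A minimal sketch of wiring a `TPUConfig` into this `RunConfig`; the
      master address and the numeric value below are placeholders chosen for
      illustration, not defaults:

        config = RunConfig(
            master='grpc://10.240.1.2:8470',  # Address of the TPU worker.
            tpu_config=TPUConfig(iterations_per_loop=100))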
"""
super(RunConfig, self).__init__(**kwargs)
self._tpu_config = tpu_config or TPUConfig()
self._cluster = cluster
    # If the user sets master and/or evaluation_master explicitly, including
    # the empty string '', take it. Otherwise, take the values set by the
    # parent class.
if master is not None:
if cluster is not None:
raise ValueError('Both master and cluster are set.')
self._master = master
else:
if cluster:
self._master = cluster.master()
if evaluation_master is not None:
self._evaluation_master = evaluation_master
elif (not self._evaluation_master and
self.task_type != run_config_lib.TaskType.EVALUATOR):
      # If the task type is EVALUATOR, it means a cluster manager has set
      # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
      #
      # Otherwise, it means the user executes the code without an external
      # cluster manager. For that case, we optimize the user experience by
      # setting evaluation_master to master, unless the user overrides it.
self._evaluation_master = self._master
# Set the ClusterSpec to use
if cluster:
self._cluster_spec = cluster.cluster_spec()
# Merge the cluster_def into the ConfigProto.
if self._session_config is None: # pylint: disable=access-member-before-definition
self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
if self._session_config.HasField('cluster_def'):
raise ValueError(
'You cannot provide a ClusterResolver and '
'session_config.cluster_def.')
if self._cluster_spec:
self._session_config.cluster_def.CopyFrom(
self._cluster_spec.as_cluster_def())
def _maybe_overwrite_session_config_for_distributed_training(self):
    # Overrides the parent class's session_config overwrite for between-graph
    # replication. TPU runs with in-graph replication, which should not have a
    # device filter. Doing nothing ("pass") here disables that overwrite.
pass
@property
def evaluation_master(self):
return self._evaluation_master
@property
def master(self):
return self._master
@property
def tpu_config(self):
return self._tpu_config
@property
def cluster(self):
return self._cluster
def replace(self, **kwargs):
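    """Returns a new instance of `RunConfig` replacing specified properties.

    Behaves like the parent class's `replace`, but additionally accepts
    `tpu_config`, which the parent class does not know about. A hypothetical
    usage sketch (the new value is illustrative):

      new_config = config.replace(
          tpu_config=TPUConfig(iterations_per_loop=50))
    """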
if 'tpu_config' not in kwargs:
return super(RunConfig, self).replace(**kwargs)
tpu_config = kwargs.pop('tpu_config')
new_instance = super(RunConfig, self).replace(**kwargs)
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
return new_instance
def _get_tpu_job_name_from_tf_config():
"""Extracts the TPU job name from TF_CONFIG env variable."""
  # TODO(xiejw): Extend this to support both the TF_CONFIG env variable and
  # cluster spec propagation.
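  # Schematically, the relevant part of TF_CONFIG is a JSON object of the form
  #   {<_SERVICE_KEY>: {"tpu_worker_job_name": "<job name>"}, ...}
  # where the job name value is whatever the cluster manager configured. This
  # only illustrates the lookup below; it is not a required full layout.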
tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
if tpu_job_name:
logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
return tpu_job_name