| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Class implementing a multi-worker parameter server tf.distribute strategy.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import copy |
| |
| |
| from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib |
| from tensorflow.python.distribute import device_util |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import distribute_utils |
| from tensorflow.python.distribute import input_lib |
| from tensorflow.python.distribute import mirrored_run |
| from tensorflow.python.distribute import multi_worker_util |
| from tensorflow.python.distribute import numpy_dataset |
| from tensorflow.python.distribute import ps_values |
| from tensorflow.python.distribute import values |
| from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver |
| from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import device as tf_device |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import variable_scope as vs |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.training import device_setter |
| from tensorflow.python.util import nest |
| from tensorflow.python.util.tf_export import tf_export |
| |
| _LOCAL_CPU = "/device:CPU:0" |
| |
| |
| @tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring |
| class ParameterServerStrategyV1(distribute_lib.StrategyV1): |
| """An asynchronous multi-worker parameter server tf.distribute strategy. |
| |
  This strategy requires two roles: workers and parameter servers. Variables,
  and updates to those variables, are assigned to parameter servers; other
  operations are assigned to workers.
| |
  When each worker has more than one GPU, operations are replicated across
  those GPUs. Variables, however, are not replicated; every worker shares a
  common view of which parameter server each variable is assigned to.
| |
  By default it uses `TFConfigClusterResolver` to detect configurations for
  multi-worker training, which requires a 'TF_CONFIG' environment variable
  whose value contains a cluster spec.
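
  For example, a minimal 'TF_CONFIG' for the first worker might look like the
  sketch below (hosts and ports are placeholders; `os` and `json` imports are
  assumed):

  ```
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {
          'worker': ['localhost:12345', 'localhost:23456'],
          'ps': ['localhost:34567']
      },
      'task': {'type': 'worker', 'index': 0}
  })
  ```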
| |
| This class assumes each worker is running the same code independently, but |
| parameter servers are running a standard server. This means that while each |
| worker will synchronously compute a single gradient update across all GPUs, |
| updates between workers proceed asynchronously. Operations that occur only on |
  the first replica (such as incrementing the global step) will occur on the
  first replica *of every worker*.
| |
  You are expected to call `call_for_each_replica(fn, ...)` for any
  operations which can potentially be replicated across replicas (i.e.
  multiple GPUs), even if the worker has only a CPU or a single GPU. When
  defining `fn`, extra caution needs to be taken:
| |
| 1) It is generally not recommended to open a device scope under the strategy's |
| scope. A device scope (i.e. calling `tf.device`) will be merged with or |
| override the device for operations but will not change the device for |
| variables. |
| |
  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead, as sketched
  below; colocating ops may create device assignment conflicts.
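
  A minimal sketch of colocating variables (variable names and values are
  illustrative):

  ```
  with strategy.scope():
    v1 = tf.Variable(1.0)
    with strategy.extended.colocate_vars_with(v1):
      # v2 is placed on the same parameter server as v1.
      v2 = tf.Variable(2.0)
  ```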
| |
  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.
| |
  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(
      train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator, ...)
| ``` |
| """ |
| |
| def __init__(self, cluster_resolver=None): |
| """Initializes this strategy with an optional `cluster_resolver`. |
| |
| Args: |
| cluster_resolver: Optional |
| `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a |
| `tf.distribute.cluster_resolver.TFConfigClusterResolver`. |
| """ |
| if cluster_resolver is None: |
| cluster_resolver = TFConfigClusterResolver() |
| super(ParameterServerStrategyV1, self).__init__( |
| ParameterServerStrategyExtended( |
| self, cluster_resolver=cluster_resolver)) |
| distribute_lib.distribution_strategy_gauge.get_cell("V1").set( |
| "ParameterServerStrategy") |
| |
| def experimental_distribute_dataset(self, dataset, options=None): |
| if (options and options.experimental_replication_mode == |
| distribute_lib.InputReplicationMode.PER_REPLICA): |
| raise NotImplementedError( |
| "InputReplicationMode.PER_REPLICA " |
| "is only supported in " |
| "`experimental_distribute_datasets_from_function`." |
| ) |
| self._raise_pss_error_if_eager() |
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)
| |
| def distribute_datasets_from_function(self, dataset_fn, options=None): |
| if (options and options.experimental_replication_mode == |
| distribute_lib.InputReplicationMode.PER_REPLICA): |
| raise NotImplementedError( |
| "InputReplicationMode.PER_REPLICA " |
| "is only supported in " |
| "`experimental_distribute_datasets_from_function` " |
| "of tf.distribute.MirroredStrategy") |
| self._raise_pss_error_if_eager() |
    return super(ParameterServerStrategyV1,
                 self).distribute_datasets_from_function(
                     dataset_fn=dataset_fn, options=options)
| |
| def run(self, fn, args=(), kwargs=None, options=None): |
| self._raise_pss_error_if_eager() |
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)
| |
| def scope(self): |
| self._raise_pss_error_if_eager() |
| return super(ParameterServerStrategyV1, self).scope() |
| |
| def _raise_pss_error_if_eager(self): |
| if context.executing_eagerly(): |
| raise NotImplementedError( |
| "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` " |
| "currently only works with the tf.Estimator API") |
| |
| |
| # TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. |
| class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): |
| """Implementation of ParameterServerStrategy and CentralStorageStrategy.""" |
| |
| def __init__(self, |
| container_strategy, |
| cluster_resolver=None, |
| compute_devices=None, |
| parameter_device=None): |
| super(ParameterServerStrategyExtended, self).__init__(container_strategy) |
| self._initialize_strategy( |
| cluster_resolver=cluster_resolver, |
| compute_devices=compute_devices, |
| parameter_device=parameter_device) |
| |
| # We typically don't need to do all-reduce in this strategy. |
| self._cross_device_ops = ( |
| cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU)) |
| |
| def _initialize_strategy(self, |
| cluster_resolver=None, |
| compute_devices=None, |
| parameter_device=None): |
| if cluster_resolver and cluster_resolver.cluster_spec(): |
| self._initialize_multi_worker(cluster_resolver) |
| else: |
| self._initialize_local( |
| compute_devices, parameter_device, cluster_resolver=cluster_resolver) |
| |
| def _initialize_multi_worker(self, cluster_resolver): |
| """Initialize devices for multiple workers. |
| |
| It creates variable devices and compute devices. Variables and operations |
| will be assigned to them respectively. We have one compute device per |
| replica. The variable device is a device function or device string. The |
| default variable device assigns variables to parameter servers in a |
| round-robin fashion. |
| |
| Args: |
| cluster_resolver: a descendant of `ClusterResolver` object. |
| |
| Raises: |
| ValueError: if the cluster doesn't have ps jobs. |
| """ |
| # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in |
| # some cases. |
| if isinstance(cluster_resolver, TFConfigClusterResolver): |
| num_gpus = context.num_gpus() |
| else: |
| num_gpus = cluster_resolver.num_accelerators().get("GPU", 0) |
| |
    # Save num_gpus_per_worker for the `configure` method.
| self._num_gpus_per_worker = num_gpus |
| |
| cluster_spec = cluster_resolver.cluster_spec() |
| task_type = cluster_resolver.task_type |
| task_id = cluster_resolver.task_id |
| if not task_type or task_id is None: |
| raise ValueError("When `cluster_spec` is given, you must also specify " |
| "`task_type` and `task_id`") |
| cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) |
| assert cluster_spec.as_dict() |
| |
| self._worker_device = "/job:%s/task:%d" % (task_type, task_id) |
| self._input_host_device = numpy_dataset.SingleDevice(self._worker_device) |
| |
    # Define compute devices, a list of device strings with one entry per
    # replica. When there are GPUs, operations are replicated on these GPUs;
    # otherwise, they are placed on the CPU.
| if num_gpus > 0: |
| compute_devices = tuple( |
| "%s/device:GPU:%d" % (self._worker_device, i) |
| for i in range(num_gpus)) |
| else: |
| compute_devices = (self._worker_device,) |
| |
| self._compute_devices = [ |
| device_util.canonicalize(d) for d in compute_devices] |
| |
| # In distributed mode, place variables on ps jobs in a round-robin fashion. |
| # Note that devices returned from `replica_device_setter` are not |
| # canonical and therefore we don't canonicalize all variable devices to |
| # make them consistent. |
| # TODO(yuefengz): support passing a strategy object to control variable |
| # assignment. |
| # TODO(yuefengz): merge the logic of replica_device_setter into this |
| # class. |
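    # For example, with three ps tasks, successively created variables are
    # placed on /job:ps/task:0, /job:ps/task:1, /job:ps/task:2,
    # /job:ps/task:0, and so on, following `replica_device_setter`'s default
    # round-robin placement.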
| num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) |
| if num_ps_replicas == 0: |
| raise ValueError("The cluster spec needs to have `ps` jobs.") |
| self._variable_device = device_setter.replica_device_setter( |
| ps_tasks=num_ps_replicas, |
| worker_device=self._worker_device, |
| merge_devices=True, |
| cluster=cluster_spec) |
| |
    # `_parameter_devices` backs the `parameter_devices` property and is a
    # list of all variable devices. Here the parameter devices are all tasks
    # of the "ps" job.
| self._parameter_devices = tuple(map("/job:ps/task:{}".format, |
| range(num_ps_replicas))) |
| |
| # Add a default device so that ops without specified devices will not end up |
| # on other workers. |
| self._default_device = self._worker_device |
| |
| self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, |
| task_id) |
| self._cluster_spec = cluster_spec |
| self._task_type = task_type |
| self._task_id = task_id |
| |
| logging.info( |
| "Multi-worker ParameterServerStrategy with " |
| "cluster_spec = %r, task_type = %r, task_id = %r, " |
| "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, " |
| "variable_device = %r", cluster_spec.as_dict(), task_type, task_id, |
| num_ps_replicas, self._is_chief, self._compute_devices, |
| self._variable_device) |
| |
| # TODO(yuefengz): get rid of cluster_resolver argument when contrib's |
| # version no longer depends on this class. |
| def _initialize_local(self, |
| compute_devices, |
| parameter_device, |
| cluster_resolver=None): |
| """Initialize local devices for training.""" |
| self._worker_device = device_util.canonicalize("/device:CPU:0") |
| self._input_host_device = numpy_dataset.SingleDevice(self._worker_device) |
| |
| if compute_devices is None: |
| if not cluster_resolver: |
| num_gpus = context.num_gpus() |
| else: |
| num_gpus = cluster_resolver.num_accelerators().get("GPU", 0) |
      # Save num_gpus_per_worker for the `configure` method, which is used by
      # the contrib version.
| self._num_gpus_per_worker = num_gpus |
| |
| compute_devices = device_util.local_devices_from_num_gpus(num_gpus) |
| |
| compute_devices = [device_util.canonicalize(d) for d in compute_devices] |
| |
| if parameter_device is None: |
| # If there is only one GPU, put everything on that GPU. Otherwise, place |
| # variables on CPU. |
| if len(compute_devices) == 1: |
| parameter_device = compute_devices[0] |
| else: |
| parameter_device = _LOCAL_CPU |
| |
| self._variable_device = parameter_device |
| self._compute_devices = compute_devices |
| self._parameter_devices = (parameter_device,) |
| self._is_chief = True |
| self._cluster_spec = None |
| self._task_type = None |
| self._task_id = None |
| |
| logging.info( |
| "ParameterServerStrategy (CentralStorageStrategy if you are using a " |
| "single machine) with compute_devices = %r, variable_device = %r", |
| compute_devices, self._variable_device) |
| |
| def _input_workers_with_options(self, options=None): |
| if not options or options.experimental_prefetch_to_device: |
| return input_lib.InputWorkers( |
| [(self._worker_device, self._compute_devices)]) |
| else: |
| return input_lib.InputWorkers( |
| [(self._worker_device, |
| (self._worker_device,) * len(self._compute_devices))]) |
| |
| @property |
| def _input_workers(self): |
| return self._input_workers_with_options() |
| |
| def _validate_colocate_with_variable(self, colocate_with_variable): |
| distribute_utils.validate_colocate(colocate_with_variable, self) |
| |
| def _experimental_distribute_dataset(self, dataset, options): |
| return input_lib.get_distributed_dataset( |
| dataset, |
| self._input_workers_with_options(options), |
| self._container_strategy(), |
| num_replicas_in_sync=self._num_replicas_in_sync) |
| |
| def _make_dataset_iterator(self, dataset): |
| return input_lib.DatasetIterator( |
| dataset, |
| self._input_workers, |
| self._container_strategy(), |
| num_replicas_in_sync=self._num_replicas_in_sync) |
| |
| def _make_input_fn_iterator( |
| self, |
| input_fn, |
| replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): |
| """Distributes the dataset to each local GPU.""" |
| if self._cluster_spec: |
| input_pipeline_id = multi_worker_util.id_in_cluster( |
| self._cluster_spec, self._task_type, self._task_id) |
| num_input_pipelines = multi_worker_util.worker_count( |
| self._cluster_spec, self._task_type) |
| else: |
| input_pipeline_id = 0 |
| num_input_pipelines = 1 |
| input_context = distribute_lib.InputContext( |
| num_input_pipelines=num_input_pipelines, |
| input_pipeline_id=input_pipeline_id, |
| num_replicas_in_sync=self._num_replicas_in_sync) |
| return input_lib.InputFunctionIterator(input_fn, self._input_workers, |
| [input_context], |
| self._container_strategy()) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| return numpy_dataset.one_host_numpy_dataset( |
| numpy_input, self._input_host_device, session) |
| |
| def _distribute_datasets_from_function(self, dataset_fn, options): |
| if self._cluster_spec: |
| input_pipeline_id = multi_worker_util.id_in_cluster( |
| self._cluster_spec, self._task_type, self._task_id) |
| num_input_pipelines = multi_worker_util.worker_count( |
| self._cluster_spec, self._task_type) |
| else: |
| input_pipeline_id = 0 |
| num_input_pipelines = 1 |
| |
| input_context = distribute_lib.InputContext( |
| num_input_pipelines=num_input_pipelines, |
| input_pipeline_id=input_pipeline_id, |
| num_replicas_in_sync=self._num_replicas_in_sync) |
| |
| return input_lib.get_distributed_datasets_from_function( |
| dataset_fn, |
| self._input_workers_with_options(options), |
| [input_context], |
| self._container_strategy()) |
| |
| def _experimental_distribute_values_from_function(self, value_fn): |
| per_replica_values = [] |
| for replica_id in range(self._num_replicas_in_sync): |
| per_replica_values.append( |
| value_fn(distribute_lib.ValueContext(replica_id, |
| self._num_replicas_in_sync))) |
| return distribute_utils.regroup(per_replica_values, always_wrap=True) |
| |
| def _broadcast_to(self, tensor, destinations): |
| # This is both a fast path for Python constants, and a way to delay |
| # converting Python values to a tensor until we know what type it |
| # should be converted to. Otherwise we have trouble with: |
| # global_step.assign_add(1) |
| # since the `1` gets broadcast as an int32 but global_step is int64. |
| if isinstance(tensor, (float, int)): |
| return tensor |
| if not cross_device_ops_lib.check_destinations(destinations): |
| # TODO(josh11b): Use current logical device instead of 0 here. |
| destinations = self._compute_devices |
| return self._cross_device_ops.broadcast(tensor, destinations) |
| |
| def _allow_variable_partition(self): |
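    # Partitioned variables are only allowed in graph mode; eager execution
    # does not support them in this strategy.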
| return not context.executing_eagerly() |
| |
  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS go through
  # this creator; for example, "MutableHashTable" does not.
| def _create_variable(self, next_creator, **kwargs): |
| if self._num_replicas_in_sync > 1: |
| aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) |
| if aggregation not in ( |
| vs.VariableAggregation.NONE, |
| vs.VariableAggregation.SUM, |
| vs.VariableAggregation.MEAN, |
| vs.VariableAggregation.ONLY_FIRST_REPLICA |
| ): |
| raise ValueError("Invalid variable aggregation mode: " + aggregation + |
| " for variable: " + kwargs["name"]) |
| |
| def var_creator(**kwargs): |
| """Create an AggregatingVariable and fix up collections.""" |
| # Record what collections this variable should be added to. |
| collections = kwargs.pop("collections", None) |
| if collections is None: |
| collections = [ops.GraphKeys.GLOBAL_VARIABLES] |
| kwargs["collections"] = [] |
| |
| # Create and wrap the variable. |
| v = next_creator(**kwargs) |
| wrapped = ps_values.AggregatingVariable(self._container_strategy(), v, |
| aggregation) |
| |
| # Add the wrapped variable to the requested collections. |
| # The handling of eager mode and the global step matches |
| # ResourceVariable._init_from_args(). |
| if not context.executing_eagerly(): |
| g = ops.get_default_graph() |
| # If "trainable" is True, next_creator() will add the contained |
| # variable to the TRAINABLE_VARIABLES collection, so we manually |
| # remove it and replace with the wrapper. We can't set "trainable" |
| # to False for next_creator() since that causes functions like |
| # implicit_gradients to skip those variables. |
| if kwargs.get("trainable", True): |
| collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) |
| l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) |
| if v in l: |
| l.remove(v) |
| g.add_to_collections(collections, wrapped) |
| elif ops.GraphKeys.GLOBAL_STEP in collections: |
| ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) |
| |
| return wrapped |
| else: |
| var_creator = next_creator |
| |
| if "colocate_with" in kwargs: |
| colocate_with = kwargs["colocate_with"] |
| if isinstance(colocate_with, numpy_dataset.SingleDevice): |
| with ops.device(colocate_with.device): |
| return var_creator(**kwargs) |
| with ops.device(None): |
| with ops.colocate_with(colocate_with): |
| return var_creator(**kwargs) |
| |
| with ops.colocate_with(None, ignore_existing=True): |
| with ops.device(self._variable_device): |
| return var_creator(**kwargs) |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| return mirrored_run.call_for_each_replica(self._container_strategy(), fn, |
| args, kwargs) |
| |
| def _verify_destinations_not_different_worker(self, destinations): |
| if not self._cluster_spec: |
| return |
| if destinations is None: |
| return |
| for d in cross_device_ops_lib.get_devices_from(destinations): |
| d_spec = tf_device.DeviceSpec.from_string(d) |
| if d_spec.job == self._task_type and d_spec.task != self._task_id: |
| raise ValueError( |
| "Cannot reduce to another worker: %r, current worker is %r" % |
| (d, self._worker_device)) |
| |
| def _gather_to_implementation(self, value, destinations, axis, |
| options): |
| self._verify_destinations_not_different_worker(destinations) |
| if not isinstance(value, values.DistributedValues): |
| return value |
| return self._cross_device_ops._gather( # pylint: disable=protected-access |
| value, |
| destinations=destinations, |
| axis=axis, |
| options=options) |
| |
| def _reduce_to(self, reduce_op, value, destinations, options): |
| self._verify_destinations_not_different_worker(destinations) |
| if not isinstance(value, values.DistributedValues): |
| # pylint: disable=protected-access |
| return cross_device_ops_lib.reduce_non_distributed_value( |
| reduce_op, value, destinations, self._num_replicas_in_sync) |
| return self._cross_device_ops.reduce( |
| reduce_op, value, destinations=destinations, options=options) |
| |
| def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): |
| for _, destinations in value_destination_pairs: |
| self._verify_destinations_not_different_worker(destinations) |
| return self._cross_device_ops.batch_reduce(reduce_op, |
| value_destination_pairs, options) |
| |
| def _select_single_value(self, structured): |
| """Select any single value in `structured`.""" |
| |
| def _select_fn(x): # pylint: disable=g-missing-docstring |
| if isinstance(x, values.Mirrored): |
| if len(x._devices) == 1: # pylint: disable=protected-access |
| return x._primary # pylint: disable=protected-access |
| else: |
          raise ValueError(
              "You cannot update a variable with a Mirrored object with "
              "multiple components %r when using ParameterServerStrategy. "
              "You must specify a single value or a Mirrored with a single "
              "value." % x)
      elif isinstance(x, values.PerReplica):
        raise ValueError(
            "You cannot update a variable with a PerReplica object %r when "
            "using ParameterServerStrategy. You must specify a single value "
            "or a Mirrored with a single value." % x)
| else: |
| return x |
| |
| return nest.map_structure(_select_fn, structured) |
| |
| def _update(self, var, fn, args, kwargs, group): |
| if isinstance(var, ps_values.AggregatingVariable): |
| var = var.get() |
| if not resource_variable_ops.is_resource_variable(var): |
      raise ValueError(
          "You cannot update `var` %r; it must be a resource variable." % var)
| with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): |
| result = fn(var, *self._select_single_value(args), |
| **self._select_single_value(kwargs)) |
| if group: |
| return result |
| else: |
| return nest.map_structure(self._local_results, result) |
| |
| # TODO(yuefengz): does it need to call _select_single_value? |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, group): |
| with ops.device( |
| colocate_with.device), distribute_lib.UpdateContext(colocate_with): |
| result = fn(*args, **kwargs) |
| if group: |
| return result |
| else: |
| return nest.map_structure(self._local_results, result) |
| |
| def _local_results(self, val): |
| if isinstance(val, values.DistributedValues): |
| return val.values |
| return (val,) |
| |
| def value_container(self, val): |
| if (hasattr(val, "_aggregating_container") and |
| not isinstance(val, ps_values.AggregatingVariable)): |
| wrapper = val._aggregating_container() # pylint: disable=protected-access |
| if wrapper is not None: |
| return wrapper |
| return val |
| |
| def read_var(self, var): |
| # No need to distinguish between normal variables and replica-local |
| # variables. |
| return array_ops.identity(var) |
| |
| def _configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| """Configures the strategy class with `cluster_spec`. |
| |
| The strategy object will be re-initialized if `cluster_spec` is passed to |
| `configure` but was not passed when instantiating the strategy. |
| |
| Args: |
| session_config: Session config object. |
| cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the |
| cluster configurations. |
| task_type: the current task type. |
| task_id: the current task id. |
| |
| Raises: |
| ValueError: if `cluster_spec` is given but `task_type` or `task_id` is |
| not. |
| """ |
| if cluster_spec: |
      # Use the num_gpus_per_worker recorded in the constructor since
      # `_configure` doesn't take num_gpus.
| cluster_resolver = SimpleClusterResolver( |
| cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), |
| task_type=task_type, |
| task_id=task_id, |
| num_accelerators={"GPU": self._num_gpus_per_worker}) |
| self._initialize_multi_worker(cluster_resolver) |
| |
| if session_config: |
| session_config.CopyFrom(self._update_config_proto(session_config)) |
| |
| def _update_config_proto(self, config_proto): |
| updated_config = copy.deepcopy(config_proto) |
| if not self._cluster_spec: |
| updated_config.isolate_session_state = True |
| return updated_config |
| |
| updated_config.isolate_session_state = False |
| |
| assert self._task_type |
| assert self._task_id is not None |
| |
| # The device filters prevent communication between workers. |
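    # For example, a "worker" task with task_id 1 keeps only
    # "/job:worker/task:1" and "/job:ps" visible, so its session cannot see
    # (or block on) devices that belong to other workers.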
| del updated_config.device_filters[:] |
| if self._task_type in ["chief", "worker"]: |
| updated_config.device_filters.extend( |
| ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) |
| elif self._task_type == "evaluator": |
| updated_config.device_filters.append( |
| "/job:%s/task:%d" % (self._task_type, self._task_id)) |
| return updated_config |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings.""" |
| return self._cluster_spec is not None |
| |
| @property |
| def _num_replicas_in_sync(self): |
| return len(self._compute_devices) |
| |
| @property |
| def worker_devices(self): |
| return self._compute_devices |
| |
| @property |
| def worker_devices_by_replica(self): |
| return [[d] for d in self._compute_devices] |
| |
| @property |
| def parameter_devices(self): |
| return self._parameter_devices |
| |
| def non_slot_devices(self, var_list): |
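    # Deterministically pick a single variable (the one whose name sorts
    # first) so that every worker agrees on the same colocation target.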
| return min(var_list, key=lambda x: x.name) |
| |
| @property |
| def experimental_between_graph(self): |
| # TODO(yuefengz): Should this return False in the local case? |
| return True |
| |
| @property |
| def experimental_should_init(self): |
| return self._is_chief |
| |
| @property |
| def should_checkpoint(self): |
| return self._is_chief |
| |
| @property |
| def should_save_summary(self): |
| return self._is_chief |
| |
| # TODO(priyag): Delete this once all strategies use global batch size. |
| @property |
| def _global_batch_size(self): |
| """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. |
| |
| `make_input_fn_iterator` assumes per-replica batching. |
| |
| Returns: |
| Boolean. |
| """ |
| return True |
| |
| def _get_local_replica_id(self, replica_id_in_sync_group): |
| return replica_id_in_sync_group |
| |
| def _get_replica_id_in_sync_group(self, replica_id): |
| return replica_id |