| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Class MirroredStrategy implementing tf.distribute.Strategy.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import copy |
| |
| from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib |
| from tensorflow.python.distribute import device_util |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import distribute_utils |
| from tensorflow.python.distribute import input_lib |
| from tensorflow.python.distribute import mirrored_run |
| from tensorflow.python.distribute import multi_worker_util |
| from tensorflow.python.distribute import numpy_dataset |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute import values |
| from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import tape |
| from tensorflow.python.framework import config |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import device as tf_device |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.util import nest |
| from tensorflow.python.util.tf_export import tf_export |
| |
| # TODO(josh11b): Replace asserts in this file with if ...: raise ... |
| |
| |
| def _is_device_list_single_worker(devices): |
| """Checks whether the devices list is for single or multi-worker. |
| |
| Args: |
    devices: a list of device strings or tf.config.LogicalDevice objects, for
      either local or remote devices.
| |
| Returns: |
    a boolean, True if all the given devices belong to a single worker and
    False otherwise.
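    For example (illustrative), `["/gpu:0", "/gpu:1"]` describes a single
    local worker, while `["/job:worker/task:0/device:GPU:0",
    "/job:worker/task:1/device:GPU:0"]` spans two workers.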
| |
| Raises: |
| ValueError: if device strings are not consistent. |
| """ |
| specs = [] |
| for d in devices: |
| name = d.name if isinstance(d, context.LogicalDevice) else d |
| specs.append(tf_device.DeviceSpec.from_string(name)) |
| num_workers = len({(d.job, d.task, d.replica) for d in specs}) |
| all_local = all(d.job in (None, "localhost") for d in specs) |
| any_local = any(d.job in (None, "localhost") for d in specs) |
| |
| if any_local and not all_local: |
| raise ValueError("Local device string cannot have job specified other " |
| "than 'localhost'") |
| |
| if num_workers == 1 and not all_local: |
| if any(d.task is None for d in specs): |
| raise ValueError("Remote device string must have task specified.") |
| |
| return num_workers == 1 |
| |
| |
| def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker): |
| """Returns a device list given a cluster spec.""" |
| cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) |
| devices = [] |
| for task_type in ("chief", "worker"): |
| for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): |
| if num_gpus_per_worker == 0: |
| devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id)) |
| else: |
| devices.extend([ |
| "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id) |
| for gpu_id in range(num_gpus_per_worker) |
| ]) |
| return devices |
| |
| |
| def _group_device_list(devices): |
| """Groups the devices list by task_type and task_id. |
| |
| Args: |
| devices: a list of device strings for remote devices. |
| |
| Returns: |
    a dict mapping from task_type to a list of device lists, one list per task
    in ascending order of task_id.
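    For example (illustrative), grouping
    `["/job:worker/task:0/device:GPU:0", "/job:worker/task:1/device:GPU:0"]`
    returns `{"worker": [["/job:worker/task:0/device:GPU:0"],
    ["/job:worker/task:1/device:GPU:0"]]}`.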
| """ |
| assert not _is_device_list_single_worker(devices) |
| device_dict = {} |
| |
| for d in devices: |
| d_spec = tf_device.DeviceSpec.from_string(d) |
| |
| # Create an entry for the task_type. |
| if d_spec.job not in device_dict: |
| device_dict[d_spec.job] = [] |
| |
| # Fill the device list for task_type until it covers the task_id. |
| while len(device_dict[d_spec.job]) <= d_spec.task: |
| device_dict[d_spec.job].append([]) |
| |
| device_dict[d_spec.job][d_spec.task].append(d) |
| |
| return device_dict |
| |
| |
| def _is_gpu_device(device): |
| return tf_device.DeviceSpec.from_string(device).device_type == "GPU" |
| |
| |
| def _infer_num_gpus_per_worker(devices): |
| """Infers the number of GPUs on each worker. |
| |
| Currently to make multi-worker cross device ops work, we need all workers to |
| have the same number of GPUs. |
| |
| Args: |
    devices: a list of device strings; can be either local or remote devices.
| |
| Returns: |
| number of GPUs per worker. |
| |
| Raises: |
    ValueError: if workers have different numbers of GPUs, or if GPU indices
      do not form a consecutive range starting from 0.
| """ |
| if _is_device_list_single_worker(devices): |
| return sum(1 for d in devices if _is_gpu_device(d)) |
| else: |
| device_dict = _group_device_list(devices) |
| num_gpus = None |
| for _, devices_in_task in device_dict.items(): |
| for device_in_task in devices_in_task: |
| if num_gpus is None: |
| num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d)) |
| |
| # Verify other workers have the same number of GPUs. |
| elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)): |
| raise ValueError("All workers should have the same number of GPUs.") |
| |
| for d in device_in_task: |
| d_spec = tf_device.DeviceSpec.from_string(d) |
| if (d_spec.device_type == "GPU" and |
| d_spec.device_index >= num_gpus): |
| raise ValueError("GPU `device_index` on a worker should be " |
| "consecutive and start from 0.") |
| return num_gpus |
| |
| |
| def all_local_devices(num_gpus=None): |
| devices = config.list_logical_devices("GPU") |
| if num_gpus is not None: |
| devices = devices[:num_gpus] |
| return devices or config.list_logical_devices("CPU") |
| |
| |
| def all_devices(): |
| devices = [] |
| tfconfig = TFConfigClusterResolver() |
| if tfconfig.cluster_spec().as_dict(): |
| devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(), |
| context.num_gpus()) |
| return devices if devices else all_local_devices() |
| |
| |
| @tf_export("distribute.MirroredStrategy", v1=[]) # pylint: disable=g-classes-have-attributes |
| class MirroredStrategy(distribute_lib.Strategy): |
| """Synchronous training across multiple replicas on one machine. |
| |
| This strategy is typically used for training on one |
| machine with multiple GPUs. For TPUs, use |
| `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers, |
| please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`. |
| |
  For example, a variable created under a `MirroredStrategy` is a
  `MirroredVariable`. If no devices are specified in the constructor argument
  of the strategy, it will use all the available GPUs. If no GPUs are found, it
  will use the available CPUs. Note that TensorFlow treats all CPUs on a
  machine as a single device, and uses threads internally for parallelism.
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> with strategy.scope(): |
| ... x = tf.Variable(1.) |
| >>> x |
| MirroredVariable:{ |
| 0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>, |
| 1: <tf.Variable ... shape=() dtype=float32, numpy=1.0> |
| } |
| |
| While using distribution strategies, all the variable creation should be done |
| within the strategy's scope. This will replicate the variables across all the |
| replicas and keep them in sync using an all-reduce algorithm. |
| |
  Variables created inside a `tf.function` that is called within a
  `MirroredStrategy` scope are still `MirroredVariable`s.
| |
| >>> x = [] |
| >>> @tf.function # Wrap the function with tf.function. |
| ... def create_variable(): |
| ... if not x: |
| ... x.append(tf.Variable(1.)) |
| ... return x[0] |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> with strategy.scope(): |
| ... _ = create_variable() |
| ... print(x[0]) |
| MirroredVariable:{ |
| 0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>, |
| 1: <tf.Variable ... shape=() dtype=float32, numpy=1.0> |
| } |
| |
| Args: |
| devices: a list of device strings such as `['/gpu:0', '/gpu:1']`. If |
| `None`, all available GPUs are used. If no GPUs are found, CPU is used. |
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
      not set, `NcclAllReduce()` will be used by default. One would customize
      this if NCCL isn't available or if a special implementation that
      exploits the particular hardware is available.
| """ |
| |
| def __init__(self, devices=None, cross_device_ops=None): |
| extended = MirroredExtended(self, devices=devices, |
| cross_device_ops=cross_device_ops) |
| super(MirroredStrategy, self).__init__(extended) |
| distribute_lib.distribution_strategy_gauge.get_cell("V2").set( |
| "MirroredStrategy") |
| |
| |
| @tf_export(v1=["distribute.MirroredStrategy"]) |
| class MirroredStrategyV1(distribute_lib.StrategyV1): # pylint: disable=g-missing-docstring |
| |
| __doc__ = MirroredStrategy.__doc__ |
| |
| def __init__(self, devices=None, cross_device_ops=None): |
| extended = MirroredExtended( |
| self, devices=devices, cross_device_ops=cross_device_ops) |
| super(MirroredStrategyV1, self).__init__(extended) |
| distribute_lib.distribution_strategy_gauge.get_cell("V1").set( |
| "MirroredStrategy") |
| |
| |
| # TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. |
| class MirroredExtended(distribute_lib.StrategyExtendedV1): |
| """Implementation of MirroredStrategy.""" |
| |
| def __init__(self, container_strategy, devices=None, cross_device_ops=None): |
| super(MirroredExtended, self).__init__(container_strategy) |
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      elif TFConfigClusterResolver().cluster_spec().as_dict():
        # In eager mode, only the single-machine code path is supported.
        logging.info("Initializing local devices since in-graph multi-worker "
                     "training with `MirroredStrategy` is not supported in "
                     "eager mode. TF_CONFIG will be ignored when "
                     "initializing `MirroredStrategy`.")
      devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()
| |
| assert devices, ("Got an empty `devices` list and unable to recognize " |
| "any local devices.") |
| self._cross_device_ops = cross_device_ops |
| self._initialize_strategy(devices) |
| |
| # TODO(b/128995245): Enable last partial batch support in graph mode. |
| if ops.executing_eagerly_outside_functions(): |
| self.experimental_enable_get_next_as_optional = True |
| |
| # Flag to turn on VariablePolicy. |
| self._use_var_policy = False |
| |
| def _initialize_strategy(self, devices): |
| # The _initialize_strategy method is intended to be used by distribute |
| # coordinator as well. |
| assert devices, "Must specify at least one device." |
| devices = tuple(device_util.resolve(d) for d in devices) |
| assert len(set(devices)) == len(devices), ( |
| "No duplicates allowed in `devices` argument: %s" % (devices,)) |
| if _is_device_list_single_worker(devices): |
| self._initialize_single_worker(devices) |
| else: |
| self._initialize_multi_worker(devices) |
| |
| def _initialize_single_worker(self, devices): |
| """Initializes the object for single-worker training.""" |
| self._devices = tuple(device_util.canonicalize(d) for d in devices) |
| self._input_workers_devices = ( |
| (device_util.canonicalize("/device:CPU:0", devices[0]), devices),) |
| |
| self._inferred_cross_device_ops = None if self._cross_device_ops else ( |
| cross_device_ops_lib.choose_the_best(devices)) |
| self._host_input_device = numpy_dataset.SingleDevice( |
| self._input_workers_devices[0][0]) |
| self._is_multi_worker_training = False |
| logging.info("Using MirroredStrategy with devices %r", devices) |
| device_spec = tf_device.DeviceSpec.from_string( |
| self._input_workers_devices[0][0]) |
    # Ensure that when we enter strategy.scope() we use the correct default
    # device.
| if device_spec.job is not None and device_spec.job != "localhost": |
| self._default_device = "/job:%s/replica:%d/task:%d" % ( |
| device_spec.job, device_spec.replica, device_spec.task) |
| |
| def _initialize_multi_worker(self, devices): |
| """Initializes the object for multi-worker training.""" |
| device_dict = _group_device_list(devices) |
| workers = [] |
| worker_devices = [] |
| for job in ("chief", "worker"): |
| for task in range(len(device_dict.get(job, []))): |
| worker = "/job:%s/task:%d" % (job, task) |
| workers.append(worker) |
| worker_devices.append((worker, device_dict[job][task])) |
| |
| # Setting `_default_device` will add a device scope in the |
| # distribution.scope. We set the default device to the first worker. When |
    # users specify a device under distribution.scope, e.g.
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the CPU device of the first worker, e.g.
| # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode. |
| self._default_device = workers[0] |
| self._host_input_device = numpy_dataset.SingleDevice(workers[0]) |
| |
| self._devices = tuple(devices) |
| self._input_workers_devices = worker_devices |
| self._is_multi_worker_training = True |
| |
| if len(workers) > 1: |
| # Grandfather usage in the legacy tests if they're configured properly. |
| if (not isinstance(self._cross_device_ops, |
| cross_device_ops_lib.ReductionToOneDevice) or |
| self._cross_device_ops._num_between_graph_workers > 1): # pylint: disable=protected-access |
| raise ValueError( |
| "In-graph multi-worker training with `MirroredStrategy` is not " |
| "supported.") |
| self._inferred_cross_device_ops = self._cross_device_ops |
| else: |
| # TODO(yuefengz): make `choose_the_best` work with device strings |
| # containing job names. |
| self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce() |
| |
| logging.info("Using MirroredStrategy with remote devices %r", devices) |
| |
  def _input_workers_with_options(self, options=None,
                                  input_workers_devices=None):
| if not input_workers_devices: |
| input_workers_devices = self._input_workers_devices |
| if not options or options.experimental_prefetch_to_device: |
| return input_lib.InputWorkers(input_workers_devices) |
| else: |
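      # Prefetching to the compute devices is disabled, so feed every replica
      # from its worker's host (input) device instead.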
| return input_lib.InputWorkers( |
| [(host_device, (host_device,) * len(compute_devices)) for |
| host_device, compute_devices in input_workers_devices]) |
| |
| @property |
| def _input_workers(self): |
| return self._input_workers_with_options() |
| |
| def _get_variable_creator_initial_value(self, |
| replica_id, |
| device, |
| primary_var, |
| **kwargs): |
| """Return the initial value for variables on a replica.""" |
| if replica_id == 0: |
| return kwargs["initial_value"] |
| else: |
| assert primary_var is not None |
| assert device is not None |
| assert kwargs is not None |
| |
| def initial_value_fn(): |
| if context.executing_eagerly() or ops.inside_function(): |
| init_value = primary_var.value() |
| return array_ops.identity(init_value) |
| else: |
| with ops.device(device): |
| init_value = primary_var.initial_value |
| return array_ops.identity(init_value) |
| |
| return initial_value_fn |
| |
| def _create_variable(self, next_creator, **kwargs): |
| """Create a mirrored variable. See `DistributionStrategy.scope`.""" |
| colocate_with = kwargs.pop("colocate_with", None) |
| if colocate_with is None: |
| devices = self._devices |
| elif isinstance(colocate_with, numpy_dataset.SingleDevice): |
| with ops.device(colocate_with.device): |
| return next_creator(**kwargs) |
| else: |
| devices = colocate_with._devices # pylint: disable=protected-access |
| |
| def _real_mirrored_creator(**kwargs): # pylint: disable=g-missing-docstring |
| value_list = [] |
| for i, d in enumerate(devices): |
| with ops.device(d): |
| kwargs["initial_value"] = self._get_variable_creator_initial_value( |
| replica_id=i, |
| device=d, |
| primary_var=value_list[0] if value_list else None, |
| **kwargs) |
| if i > 0: |
| # Give replicas meaningful distinct names: |
| var0name = value_list[0].name.split(":")[0] |
| # We append a / to variable names created on replicas with id > 0 to |
| # ensure that we ignore the name scope and instead use the given |
| # name as the absolute name of the variable. |
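            # For example (illustrative), if the primary variable is named
            # "v", the copy on replica 1 gets the absolute name "v/replica_1/".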
| kwargs["name"] = "%s/replica_%d/" % (var0name, i) |
| with context.device_policy(context.DEVICE_PLACEMENT_SILENT): |
| # Don't record operations (e.g. other variable reads) during |
| # variable creation. |
| with tape.stop_recording(): |
| v = next_creator(**kwargs) |
| assert not isinstance(v, values.DistributedVariable) |
| value_list.append(v) |
| return value_list |
| |
| return distribute_utils.create_mirrored_variable( |
| self._container_strategy(), _real_mirrored_creator, |
| distribute_utils.VARIABLE_CLASS_MAPPING, |
| distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs) |
| |
| def _validate_colocate_with_variable(self, colocate_with_variable): |
| distribute_utils.validate_colocate_distributed_variable( |
| colocate_with_variable, self) |
| |
| def _make_dataset_iterator(self, dataset): |
| return input_lib.DatasetIterator( |
| dataset, |
| self._input_workers, |
| self._container_strategy(), |
| split_batch_by=self._num_replicas_in_sync) |
| |
| def _make_input_fn_iterator( |
| self, |
| input_fn, |
| replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): |
| input_contexts = [] |
| num_workers = self._input_workers.num_workers |
| for i in range(num_workers): |
| input_contexts.append(distribute_lib.InputContext( |
| num_input_pipelines=num_workers, |
| input_pipeline_id=i, |
| num_replicas_in_sync=self._num_replicas_in_sync)) |
| return input_lib.InputFunctionIterator(input_fn, self._input_workers, |
| input_contexts, |
| self._container_strategy()) |
| |
| def _experimental_distribute_dataset(self, dataset, options): |
    if (options and options.replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise RuntimeError(
          "InputReplicationMode.PER_REPLICA is only supported in "
          "`experimental_distribute_datasets_from_function`.")
| return input_lib.get_distributed_dataset( |
| dataset, |
| self._input_workers_with_options(options), |
| self._container_strategy(), |
| split_batch_by=self._num_replicas_in_sync) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| return numpy_dataset.one_host_numpy_dataset( |
| numpy_input, self._host_input_device, session) |
| |
  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options.replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      self._input_workers_devices = tuple(
          (device_util.canonicalize("/device:CPU:0", d), (d,))
          for d in self._devices)
| input_workers = self._input_workers_with_options( |
| None, self._input_workers_devices) |
| else: |
| input_workers = self._input_workers_with_options( |
| options, self._input_workers_devices) |
| input_contexts = [] |
| num_workers = input_workers.num_workers |
| for i in range(num_workers): |
| input_contexts.append(distribute_lib.InputContext( |
| num_input_pipelines=num_workers, |
| input_pipeline_id=i, |
| num_replicas_in_sync=self._num_replicas_in_sync)) |
| |
| return input_lib.get_distributed_datasets_from_function( |
| dataset_fn, |
| input_workers, |
| input_contexts, |
| self._container_strategy(), |
| options.replication_mode) |
| |
| def _experimental_distribute_values_from_function(self, value_fn): |
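    # Illustrative use via the public API (a sketch):
    #   strategy.experimental_distribute_values_from_function(
    #       lambda ctx: tf.constant(ctx.replica_id_in_sync_group))
    # produces a per-replica value whose component on replica i holds i.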
| per_replica_values = [] |
| for replica_id in range(self._num_replicas_in_sync): |
| per_replica_values.append(value_fn( |
| distribute_lib.ValueContext(replica_id, |
| self._num_replicas_in_sync))) |
| return distribute_utils.regroup(per_replica_values, always_wrap=True) |
| |
| # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. |
| def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, |
| initial_loop_values=None): |
| if initial_loop_values is None: |
| initial_loop_values = {} |
| initial_loop_values = nest.flatten(initial_loop_values) |
| |
| ctx = input_lib.MultiStepContext() |
| def body(i, *args): |
| """A wrapper around `fn` to create the while loop body.""" |
| del args |
| fn_result = fn(ctx, iterator.get_next()) |
| for (name, output) in ctx.last_step_outputs.items(): |
| # Convert all outputs to tensors, potentially from `DistributedValues`. |
| ctx.last_step_outputs[name] = self._local_results(output) |
| flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) |
| with ops.control_dependencies([fn_result]): |
| return [i + 1] + flat_last_step_outputs |
| |
| # We capture the control_flow_context at this point, before we run `fn` |
| # inside a while_loop. This is useful in cases where we might need to exit |
    # these contexts and get back to the outer context to do some things, e.g.
    # create an op which should be evaluated only once at the end of the
| # loop on the host. One such usage is in creating metrics' value op. |
| self._outer_control_flow_context = ( |
| ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access |
| |
| cond = lambda i, *args: i < iterations |
| i = constant_op.constant(0) |
| loop_result = control_flow_ops.while_loop( |
| cond, body, [i] + initial_loop_values, name="", |
| parallel_iterations=1, back_prop=False, swap_memory=False, |
| return_same_structure=True) |
| del self._outer_control_flow_context |
| |
| ctx.run_op = control_flow_ops.group(loop_result) |
| |
| # Convert the last_step_outputs from a list to the original dict structure |
| # of last_step_outputs. |
| last_step_tensor_outputs = loop_result[1:] |
| last_step_tensor_outputs_dict = nest.pack_sequence_as( |
| ctx.last_step_outputs, last_step_tensor_outputs) |
| |
| for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access |
| output = last_step_tensor_outputs_dict[name] |
| # For outputs that have already been reduced, wrap them in a Mirrored |
| # container, else in a PerReplica container. |
| if reduce_op is None: |
| last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output) |
| else: |
| assert len(output) == 1 |
| last_step_tensor_outputs_dict[name] = output[0] |
| |
| ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access |
| return ctx |
| |
| def _broadcast_to(self, tensor, destinations): |
| # This is both a fast path for Python constants, and a way to delay |
| # converting Python values to a tensor until we know what type it |
| # should be converted to. Otherwise we have trouble with: |
| # global_step.assign_add(1) |
| # since the `1` gets broadcast as an int32 but global_step is int64. |
| if isinstance(tensor, (float, int)): |
| return tensor |
| # TODO(josh11b): In eager mode, use one thread per device, or async mode. |
| if not destinations: |
| # TODO(josh11b): Use current logical device instead of 0 here. |
| destinations = self._devices |
| return self._get_cross_device_ops(tensor).broadcast(tensor, destinations) |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| return mirrored_run.call_for_each_replica( |
| self._container_strategy(), fn, args, kwargs) |
| |
| def _configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| del task_type, task_id |
| |
| if session_config: |
| session_config.CopyFrom(self._update_config_proto(session_config)) |
| |
| if cluster_spec: |
| # TODO(yuefengz): remove the following code once cluster_resolver is |
| # added. |
| num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices) |
| multi_worker_devices = _cluster_spec_to_device_list( |
| cluster_spec, num_gpus_per_worker) |
| self._initialize_multi_worker(multi_worker_devices) |
| |
| def _update_config_proto(self, config_proto): |
| updated_config = copy.deepcopy(config_proto) |
| updated_config.isolate_session_state = True |
| return updated_config |
| |
| def _get_cross_device_ops(self, value): |
| del value # Unused. |
| return self._cross_device_ops or self._inferred_cross_device_ops |
| |
| def _reduce_to(self, reduce_op, value, destinations, experimental_hints): |
| if (distribute_utils.is_mirrored(value) and |
| reduce_op == reduce_util.ReduceOp.MEAN): |
| return value |
| assert not distribute_utils.is_mirrored(value) |
| if not isinstance(value, values.DistributedValues): |
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or `value`
      # could be 0.
| return cross_device_ops_lib.reduce_non_distributed_value( |
| reduce_op, value, destinations, self._num_replicas_in_sync) |
| return self._get_cross_device_ops(value).reduce( |
| reduce_op, |
| value, |
| destinations=destinations, |
| experimental_hints=experimental_hints) |
| |
| def _batch_reduce_to(self, reduce_op, value_destination_pairs, |
| experimental_hints): |
| cross_device_ops = None |
| for value, _ in value_destination_pairs: |
| if cross_device_ops is None: |
| cross_device_ops = self._get_cross_device_ops(value) |
| elif cross_device_ops is not self._get_cross_device_ops(value): |
| raise ValueError("inputs to batch_reduce_to must be either all on the " |
| "the host or all on the compute devices") |
| return cross_device_ops.batch_reduce(reduce_op, value_destination_pairs, |
| experimental_hints) |
| |
| def _update(self, var, fn, args, kwargs, group): |
| # TODO(josh11b): In eager mode, use one thread per device. |
| assert isinstance(var, values.DistributedVariable) |
| updates = [] |
| for i, v in enumerate(var.values): |
| name = "update_%d" % i |
| with ops.device(v.device), \ |
| distribute_lib.UpdateContext(i), \ |
| ops.name_scope(name): |
| # If args and kwargs are not mirrored, the value is returned as is. |
| updates.append( |
| fn(v, *distribute_utils.select_replica_mirrored(i, args), |
| **distribute_utils.select_replica_mirrored(i, kwargs))) |
| return distribute_utils.update_regroup(self, updates, group) |
| |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, group): |
| assert isinstance(colocate_with, tuple) |
| # TODO(josh11b): In eager mode, use one thread per device. |
| updates = [] |
| for i, d in enumerate(colocate_with): |
| name = "update_%d" % i |
| with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name): |
| updates.append( |
| fn(*distribute_utils.select_replica_mirrored(i, args), |
| **distribute_utils.select_replica_mirrored(i, kwargs))) |
| return distribute_utils.update_regroup(self, updates, group) |
| |
| def read_var(self, replica_local_var): |
| """Read the aggregate value of a replica-local variable.""" |
| # pylint: disable=protected-access |
| if distribute_utils.is_sync_on_read(replica_local_var): |
| return replica_local_var._get_cross_replica() |
| assert distribute_utils.is_mirrored(replica_local_var) |
| return array_ops.identity(replica_local_var._get()) |
| # pylint: enable=protected-access |
| |
| def _local_results(self, val): |
| if isinstance(val, values.DistributedValues): |
| return val._values # pylint: disable=protected-access |
| return (val,) |
| |
| def value_container(self, val): |
| return distribute_utils.value_container(val) |
| |
| @property |
| def _num_replicas_in_sync(self): |
| return len(self._devices) |
| |
| @property |
| def worker_devices(self): |
| return self._devices |
| |
| @property |
| def worker_devices_by_replica(self): |
| return [[d] for d in self._devices] |
| |
| @property |
| def parameter_devices(self): |
| return self.worker_devices |
| |
| @property |
| def experimental_between_graph(self): |
| return False |
| |
| @property |
| def experimental_should_init(self): |
| return True |
| |
| @property |
| def should_checkpoint(self): |
| return True |
| |
| @property |
| def should_save_summary(self): |
| return True |
| |
| def non_slot_devices(self, var_list): |
| del var_list |
| # TODO(josh11b): Should this be the last logical device instead? |
| return self._devices |
| |
| # TODO(priyag): Delete this once all strategies use global batch size. |
| @property |
| def _global_batch_size(self): |
| """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. |
| |
| `make_input_fn_iterator` assumes per-replica batching. |
| |
| Returns: |
| Boolean. |
| """ |
| return True |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings.""" |
| return False |