# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import copy
import functools
import threading
import weakref
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# TODO(josh11b): Replace asserts in this file with if ...: raise ...
@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
"""Context manager for selecting a graph and maybe eager mode."""
if eager:
with g.as_default(), context.eager_mode():
if creator_stack is not None:
g._variable_creator_stack = creator_stack # pylint: disable=protected-access
yield
else:
with g.as_default():
if creator_stack is not None:
g._variable_creator_stack = creator_stack # pylint: disable=protected-access
yield
def _cpu_device(device):
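"""Returns the CPU:0 device string on the same worker as `device`."""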
cpu_device = tf_device.DeviceSpec.from_string(device)
cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
return cpu_device.to_string()
class _RequestedStop(Exception): # pylint: disable=g-bad-exception-name
pass
# _call_for_each_replica is not a member of MirroredStrategy so that it is
# not allowed to use anything specific to MirroredStrategy and thus
# can be shared with other distribution strategies.
# TODO(yuefengz): maybe create a common class for those who need to call this
# _call_for_each_replica.
def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
"""Run `fn` in separate threads, once per replica/worker device.
Args:
distribution: the DistributionStrategy object.
device_map: the DeviceMap with the devices to run `fn` on.
fn: function to run (will be run once per replica, each in its own thread).
args: positional arguments for `fn`
kwargs: keyword arguments for `fn`.
Returns:
Merged return value of `fn` across all replicas.
Raises:
RuntimeError: If fn() calls get_replica_context().merge_call() a different
number of times on different replicas.
"""
# TODO(josh11b): Add this option once we add synchronization to variable
# creation. Until then, this is pretty unsafe to use.
run_concurrently = False
if not context.executing_eagerly():
# Needed for per-thread device, etc. contexts in graph mode.
ops.get_default_graph().switch_to_thread_local()
coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
shared_variable_store = {}
# TODO(isaprykin): Create these threads once instead of during every call.
threads = []
for index in range(device_map.num_replicas_in_graph):
variable_creator_fn = shared_variable_creator.make_fn(
shared_variable_store, index)
t = _MirroredReplicaThread(
distribution, coord, index, device_map, variable_creator_fn, fn,
values.select_replica(index, args),
values.select_replica(index, kwargs))
threads.append(t)
for t in threads:
t.start()
# When `fn` starts, the `should_run` event is set on the
# _MirroredReplicaThread (`MRT`) threads. The main thread then waits until
# `MRT.has_paused` is set, which indicates that either `fn` is
# complete or a `get_replica_context().merge_call()` is called. If `fn` is
# complete, then `MRT.done` is set to True. Otherwise, arguments
# of `get_replica_context().merge_call` from all paused threads are grouped
# and the `merge_fn` is performed. Results of the
# `get_replica_context().merge_call` are then set to `MRT.merge_result`.
# Each such `get_replica_context().merge_call` call returns the
# `MRT.merge_result` for that thread when `MRT.should_run` event
# is reset again. Execution of `fn` resumes.
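# Schematically, each iteration of the loop below looks like:
#   main thread                          replica (MRT) threads
#   t.should_run.set()  -------------->  run `fn` until merge_call() or return
#   t.has_paused.wait() <--------------  t.has_paused.set()
#   if not all done: run merge_fn,
#   set t.merge_result, repeat  ------->  merge_call() returns t.merge_result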
try:
with coord.stop_on_exception():
all_done = False
while not all_done and not coord.should_stop():
done = []
if run_concurrently:
for t in threads:
t.should_run.set()
for t in threads:
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
else:
for t in threads:
t.should_run.set()
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
if coord.should_stop():
return None
all_done = all(done)
if not all_done:
if any(done):
raise RuntimeError("Some replicas made a different number of "
"replica_context().merge_call() calls.")
# get_replica_context().merge_call() case
merge_args = values.regroup(
device_map, tuple(t.merge_args for t in threads))
merge_kwargs = values.regroup(
device_map, tuple(t.merge_kwargs for t in threads))
# We capture the name_scope of the MRT when we call merge_fn
# to ensure that if we have opened a name scope in the MRT,
# it will be respected when executing the merge function. We only
# capture the name_scope from the first MRT and assume it is
# the same for all other MRTs.
mtt_captured_name_scope = threads[0].captured_name_scope
mtt_captured_var_scope = threads[0].captured_var_scope
# Capture and merge the control dependencies from all the threads.
mtt_captured_control_deps = set()
for t in threads:
mtt_captured_control_deps.update(t.captured_control_deps)
with ops.name_scope(mtt_captured_name_scope),\
ops.control_dependencies(mtt_captured_control_deps), \
variable_scope.variable_scope(mtt_captured_var_scope):
merge_result = threads[0].merge_fn(distribution, *merge_args,
**merge_kwargs)
for r, t in enumerate(threads):
t.merge_result = values.select_replica(r, merge_result)
finally:
for t in threads:
t.should_run.set()
coord.join(threads)
return values.regroup(device_map, tuple(t.main_result for t in threads))
def _is_device_list_single_worker(devices):
"""Checks whether the devices list is for single or multi-worker.
Args:
devices: a list of device strings, either local or for remote devices.
Returns:
a boolean, True if the devices all belong to a single worker and False if
they span multiple workers.
Raises:
ValueError: if device strings are not consistent.
"""
# Use a list (not a generator) since `specs` is iterated more than once below.
specs = [tf_device.DeviceSpec.from_string(d) for d in devices]
num_workers = len({(d.job, d.task, d.replica) for d in specs})
all_local = all(d.job in (None, "localhost") for d in specs)
any_local = any(d.job in (None, "localhost") for d in specs)
if any_local and not all_local:
raise ValueError("Local device string cannot have job specified other "
"than 'localhost'")
if num_workers == 1 and not all_local:
if any(d.task is None for d in specs):
raise ValueError("Remote device string must have task specified.")
return num_workers == 1
def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
"""Returns a device list given a cluster spec."""
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
devices = []
for task_type in ("chief", "worker"):
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
if num_gpus_per_worker == 0:
devices.append("/job:%s/task:%d" % (task_type, task_id))
else:
devices.extend([
"/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
for gpu_id in range(num_gpus_per_worker)
])
return devices
def _group_device_list(devices):
"""Groups the devices list by task_type and task_id.
Args:
devices: a list of device strings for remote devices.
Returns:
a dict mapping task_type to a list of per-task device lists, ordered by
ascending task_id.
"""
assert not _is_device_list_single_worker(devices)
device_dict = {}
for d in devices:
d_spec = tf_device.DeviceSpec.from_string(d)
# Create an entry for the task_type.
if d_spec.job not in device_dict:
device_dict[d_spec.job] = []
# Fill the device list for task_type until it covers the task_id.
while len(device_dict[d_spec.job]) <= d_spec.task:
device_dict[d_spec.job].append([])
device_dict[d_spec.job][d_spec.task].append(d)
return device_dict
def _is_gpu_device(device):
return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
def _infer_num_gpus_per_worker(devices):
"""Infers the number of GPUs on each worker.
Currently to make multi-worker cross device ops work, we need all workers to
have the same number of GPUs.
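For example, ["/job:worker/task:0/gpu:0", "/job:worker/task:0/gpu:1",
"/job:worker/task:1/gpu:0", "/job:worker/task:1/gpu:1"] yields 2.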
Args:
devices: a list of device strings, can be either local devices or remote
devices.
Returns:
number of GPUs per worker.
Raises:
ValueError: if workers have different numbers of GPUs, or if GPU indices are
not consecutive starting from 0.
"""
if _is_device_list_single_worker(devices):
return sum(1 for d in devices if _is_gpu_device(d))
else:
device_dict = _group_device_list(devices)
num_gpus = None
for _, devices_in_task in device_dict.items():
for device_in_task in devices_in_task:
if num_gpus is None:
num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))
# Verify other workers have the same number of GPUs.
elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
raise ValueError("All workers should have the same number of GPUs.")
for d in device_in_task:
d_spec = tf_device.DeviceSpec.from_string(d)
if (d_spec.device_type == "GPU" and
d_spec.device_index >= num_gpus):
raise ValueError("GPU `device_index` on a worker should be "
"consecutive and start from 0.")
return num_gpus
def all_local_devices(num_gpus=None):
if num_gpus is None:
num_gpus = context.num_gpus()
return device_util.local_devices_from_num_gpus(num_gpus)
def all_devices():
devices = []
tfconfig = TFConfigClusterResolver()
if tfconfig.cluster_spec().as_dict():
devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
context.num_gpus())
return devices if devices else all_local_devices()
@tf_export("distribute.MirroredStrategy", v1=[]) # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
"""Mirrors vars to distribute across multiple devices and machines.
This strategy uses one replica per device and sync replication for its
multi-GPU version.
To use `MirroredStrategy` with multiple workers, please refer to
`tf.distribute.MultiWorkerMirroredStrategy`.
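For example, a minimal usage sketch (assuming two local GPUs are available):

```python
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
with strategy.scope():
  # Variables created in this scope (e.g. by a Keras model) are mirrored
  # across both GPUs.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
```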
Args:
devices: a list of device strings. If `None`, all available GPUs are used.
If no GPUs are found, CPU is used.
cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
not set, NCCL will be used by default.
"""
def __init__(self, devices=None, cross_device_ops=None):
extended = MirroredExtended(
self, devices=devices, cross_device_ops=cross_device_ops)
super(MirroredStrategy, self).__init__(extended)
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
"MirroredStrategy")
@tf_export(v1=["distribute.MirroredStrategy"])
class MirroredStrategyV1(distribute_lib.StrategyV1): # pylint: disable=g-missing-docstring
__doc__ = MirroredStrategy.__doc__
def __init__(self, devices=None, cross_device_ops=None):
extended = MirroredExtended(
self, devices=devices, cross_device_ops=cross_device_ops)
super(MirroredStrategyV1, self).__init__(extended)
distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
"MirroredStrategy")
# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class MirroredExtended(distribute_lib.StrategyExtendedV1):
"""Implementation of MirroredStrategy."""
def __init__(self, container_strategy, devices=None, cross_device_ops=None):
super(MirroredExtended, self).__init__(container_strategy)
if context.executing_eagerly():
if devices and not _is_device_list_single_worker(devices):
raise RuntimeError("In-graph multi-worker training with "
"`MirroredStrategy` is not supported in eager mode.")
else:
if TFConfigClusterResolver().cluster_spec().as_dict():
# if you are executing in eager mode, only the single machine code
# path is supported.
logging.info("Initializing local devices since in-graph multi-worker "
"training with `MirroredStrategy` is not supported in "
"eager mode. TF_CONFIG will be ignored when "
"when initializing `MirroredStrategy`.")
devices = devices or all_local_devices()
else:
devices = devices or all_devices()
assert devices, ("Got an empty `devices` list and unable to recognize "
"any local devices.")
self._cross_device_ops = cross_device_ops
self._initialize_strategy(devices)
self._cfer_fn_cache = weakref.WeakKeyDictionary()
# TODO(b/128995245): Enable last partial batch support in graph mode.
if ops.executing_eagerly_outside_functions():
self.experimental_enable_get_next_as_optional = True
def _initialize_strategy(self, devices):
# The _initialize_strategy method is intended to be used by distribute
# coordinator as well.
assert devices, "Must specify at least one device."
devices = tuple(device_util.resolve(d) for d in devices)
assert len(set(devices)) == len(devices), (
"No duplicates allowed in `devices` argument: %s" % (devices,))
if _is_device_list_single_worker(devices):
self._initialize_single_worker(devices)
else:
self._initialize_multi_worker(devices)
def _initialize_single_worker(self, devices):
"""Initializes the object for single-worker training."""
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = input_lib.InputWorkers(self._device_map)
self._inferred_cross_device_ops = None if self._cross_device_ops else (
cross_device_ops_lib.choose_the_best(devices))
self._host_input_device = numpy_dataset.SingleDevice(
self._input_workers.worker_devices[0])
self._is_multi_worker_training = False
logging.info("Using MirroredStrategy with devices %r", devices)
device_spec = tf_device.DeviceSpec.from_string(
self._input_workers.worker_devices[0])
# Ensures that when we enter strategy.scope() we use the correct default device.
if device_spec.job is not None and device_spec.job != "localhost":
self._default_device = "/job:%s/replica:%d/task:%d" % (
device_spec.job, device_spec.replica, device_spec.task)
def _initialize_multi_worker(self, devices):
"""Initializes the object for multi-worker training."""
device_dict = _group_device_list(devices)
workers = []
worker_devices = []
for job in ("chief", "worker"):
for task in range(len(device_dict.get(job, []))):
worker = "/job:%s/task:%d" % (job, task)
workers.append(worker)
worker_devices.append((worker, device_dict[job][task]))
# Setting `_default_device` will add a device scope in the
# distribution.scope. We set the default device to the first worker. When
# users specify device under distribution.scope by
# with tf.device("/cpu:0"):
# ...
# their ops will end up on the CPU device of the first worker, e.g.
# "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
self._default_device = workers[0]
self._host_input_device = numpy_dataset.SingleDevice(workers[0])
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = input_lib.InputWorkers(
self._device_map, worker_devices)
self._is_multi_worker_training = True
if len(workers) > 1:
if not isinstance(self._cross_device_ops,
cross_device_ops_lib.MultiWorkerAllReduce):
raise ValueError(
"In-graph multi-worker training with `MirroredStrategy` is not "
"supported.")
self._inferred_cross_device_ops = self._cross_device_ops
else:
# TODO(yuefengz): make `choose_the_best` work with device strings
# containing job names.
self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
logging.info("Using MirroredStrategy with remote devices %r", devices)
def _get_variable_creator_initial_value(self,
replica_id,
device,
primary_var,
**kwargs):
"""Return the initial value for variables on a replica."""
if replica_id == 0:
return kwargs["initial_value"]
else:
assert primary_var is not None
assert device is not None
assert kwargs is not None
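# For non-primary replicas, return a function that copies the primary
# variable's (possibly already initialized) value, so every replica starts
# with identical contents.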
def initial_value_fn():
if context.executing_eagerly() or ops.inside_function():
init_value = primary_var.value()
return array_ops.identity(init_value)
else:
with ops.device(device):
init_value = primary_var.initial_value
return array_ops.identity(init_value)
return initial_value_fn
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
device_map = self._device_map
logical_device = 0 # TODO(josh11b): Get logical device from scope here.
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
else:
device_map = colocate_with.device_map
logical_device = colocate_with.logical_device
def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
value_list = []
for i, d in enumerate(devices):
with ops.device(d):
kwargs["initial_value"] = self._get_variable_creator_initial_value(
replica_id=i,
device=d,
primary_var=value_list[0] if value_list else None,
**kwargs)
if i > 0:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
# We append a / to variable names created on replicas with id > 0 to
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
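# For example, if replica 0's variable is named "v", the name passed
# for replica 1 is "v/replica_1/".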
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
# Don't record operations (e.g. other variable reads) during
# variable creation.
with tape.stop_recording():
v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
value_list.append(v)
return value_list
return values.create_mirrored_variable(
self._container_strategy(), device_map, logical_device,
_real_mirrored_creator, values.MirroredVariable,
values.SyncOnReadVariable, *args, **kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate_distributed_variable(colocate_with_variable, self)
def _make_dataset_iterator(self, dataset):
return input_lib.DatasetIterator(
dataset,
self._input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
input_contexts = []
num_workers = self._input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
num_input_pipelines=num_workers,
input_pipeline_id=i,
num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib.InputFunctionIterator(input_fn, self._input_workers,
input_contexts,
self._container_strategy())
def _experimental_distribute_dataset(self, dataset):
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._host_input_device, session)
def _experimental_distribute_datasets_from_function(self, dataset_fn):
input_contexts = []
num_workers = self._input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
num_input_pipelines=num_workers,
input_pipeline_id=i,
num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers,
input_contexts,
self._container_strategy())
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values=None):
if initial_loop_values is None:
initial_loop_values = {}
initial_loop_values = nest.flatten(initial_loop_values)
ctx = input_lib.MultiStepContext()
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
fn_result = fn(ctx, iterator.get_next())
for (name, output) in ctx.last_step_outputs.items():
# Convert all outputs to tensors, potentially from `DistributedValues`.
ctx.last_step_outputs[name] = self._local_results(output)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
# We capture the control_flow_context at this point, before we run `fn`
# inside a while_loop. This is useful in cases where we might need to exit
# these contexts and get back to the outer context to do some things, e.g.
# to create an op which should be evaluated only once at the end of the
# loop on the host. One such usage is in creating metrics' value op.
self._outer_control_flow_context = (
ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
# Convert the last_step_outputs from a list to the original dict structure
# of last_step_outputs.
last_step_tensor_outputs = loop_result[1:]
last_step_tensor_outputs_dict = nest.pack_sequence_as(
ctx.last_step_outputs, last_step_tensor_outputs)
for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
output = last_step_tensor_outputs_dict[name]
# For outputs that have already been reduced, wrap them in a Mirrored
# container, else in a PerReplica container.
if reduce_op is None:
last_step_tensor_outputs_dict[name] = values.regroup(self._device_map,
output)
else:
assert len(output) == 1
last_step_tensor_outputs_dict[name] = output[0]
ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access
return ctx
def _broadcast_to(self, tensor, destinations):
# This is both a fast path for Python constants, and a way to delay
# converting Python values to a tensor until we know what type it
# should be converted to. Otherwise we have trouble with:
# global_step.assign_add(1)
# since the `1` gets broadcast as an int32 but global_step is int64.
if isinstance(tensor, (float, int)):
return tensor
# TODO(josh11b): In eager mode, use one thread per device, or async mode.
if not destinations:
# TODO(josh11b): Use current logical device instead of 0 here.
destinations = values.LogicalDeviceSpec(
device_map=self._device_map, logical_device=0)
return self._get_cross_device_ops().broadcast(tensor, destinations)
def _call_for_each_replica(self, fn, args, kwargs):
if isinstance(fn, def_function.Function):
wrapped = self._cfer_fn_cache.get(fn)
if wrapped is None:
# We need to wrap fn such that it triggers _call_for_each_replica inside
# the tf.function.
wrapped = fn._clone( # pylint: disable=protected-access
python_function=functools.partial(self._call_for_each_replica,
fn.python_function))
self._cfer_fn_cache[fn] = wrapped
return wrapped(args, kwargs)
if context.executing_eagerly():
logging.log_first_n(logging.WARN, "Using %s eagerly has significant "
"overhead currently. We will be working on improving "
"this in the future, but for now please wrap "
"`call_for_each_replica` or `experimental_run` or "
"`experimental_run_v2` inside a tf.function to get "
"the best performance." %
self._container_strategy().__class__.__name__, 5)
return _call_for_each_replica(self._container_strategy(), self._device_map,
fn, args, kwargs)
def _configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
del task_type, task_id
if session_config:
session_config.CopyFrom(self._update_config_proto(session_config))
if cluster_spec:
# TODO(yuefengz): remove the following code once cluster_resolver is
# added.
num_gpus_per_worker = _infer_num_gpus_per_worker(
self._device_map.all_devices)
multi_worker_devices = _cluster_spec_to_device_list(
cluster_spec, num_gpus_per_worker)
self._initialize_multi_worker(multi_worker_devices)
def _update_config_proto(self, config_proto):
updated_config = copy.deepcopy(config_proto)
updated_config.isolate_session_state = True
return updated_config
def _get_cross_device_ops(self):
return self._cross_device_ops or self._inferred_cross_device_ops
def _reduce_to(self, reduce_op, value, destinations):
if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN):
return value
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.DistributedValues):
# This function handles reducing values that are not PerReplica or
# Mirrored values. For example, the same value could be present on all
# replicas, in which case `value` would be a single value, or `value`
# could be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, self._device_map, value, destinations)
return self._get_cross_device_ops().reduce(
reduce_op, value, destinations=destinations)
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
return self._get_cross_device_ops().batch_reduce(reduce_op,
value_destination_pairs)
def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.
assert isinstance(var, values.DistributedVariable)
updates = []
for i, (d, v) in enumerate(zip(var.devices, var.values)):
name = "update_%d" % i
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
# If args and kwargs are not mirrored, the value is returned as is.
updates.append(fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs)))
return values.update_regroup(self, self._device_map, updates, group)
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
assert isinstance(colocate_with, tuple)
# TODO(josh11b): In eager mode, use one thread per device.
updates = []
for i, d in enumerate(colocate_with):
name = "update_%d" % i
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
updates.append(fn(*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs)))
return values.update_regroup(self, self._device_map, updates, group)
def read_var(self, replica_local_var):
"""Read the aggregate value of a replica-local variable."""
if isinstance(replica_local_var, values.SyncOnReadVariable):
return replica_local_var._get_cross_replica() # pylint: disable=protected-access
assert isinstance(replica_local_var, values.Mirrored)
return array_ops.identity(replica_local_var.get())
def _local_results(self, val):
if isinstance(val, values.DistributedValues):
return val.values
return (val,)
def value_container(self, val):
return values.value_container(val)
@property
def _num_replicas_in_sync(self):
return self._device_map.num_replicas_in_graph
@property
def worker_devices(self):
return self._device_map.all_devices
@property
def worker_devices_by_replica(self):
return self._device_map.devices_by_replica
@property
def parameter_devices(self):
return self._device_map.all_devices
@property
def experimental_between_graph(self):
return False
@property
def experimental_should_init(self):
return True
@property
def should_checkpoint(self):
return True
@property
def should_save_summary(self):
return True
def non_slot_devices(self, var_list):
del var_list
# TODO(josh11b): Should this be the last logical device instead?
return self._device_map.logical_to_actual_devices(0)
# TODO(priyag): Delete this once all strategies use global batch size.
@property
def _global_batch_size(self):
"""`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
`make_input_fn_iterator` assumes per-replica batching.
Returns:
Boolean.
"""
return True
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
return False
class _MirroredReplicaThread(threading.Thread):
"""A thread that runs() a function on a device."""
def __init__(self, dist, coord, replica_id, device_map, variable_creator_fn,
fn, args, kwargs):
super(_MirroredReplicaThread, self).__init__()
self.coord = coord
self.distribution = dist
self.device_map = device_map
self.replica_id = replica_id
self.variable_creator_fn = variable_creator_fn
# State needed to run and return the results of `fn`.
self.main_fn = fn
self.main_args = args
self.main_kwargs = kwargs
self.main_result = None
self.done = False
# State needed to run the next merge_call() (if any) requested via
# ReplicaContext.
self.merge_fn = None
self.merge_args = None
self.merge_kwargs = None
self.merge_result = None
self.captured_name_scope = None
self.captured_var_scope = None
# We use a thread.Event for the main thread to signal when this
# thread should start running (`should_run`), and another for
# this thread to transfer control back to the main thread
# (`has_paused`, either when it gets to a
# `get_replica_context().merge_call` or when `fn` returns). In
# either case the event starts cleared and is signaled by calling
# set(). The receiving thread waits for the signal by calling
# wait() and then immediately clearing the event using clear().
self.should_run = threading.Event()
self.has_paused = threading.Event()
# These fields have to do with inheriting various contexts from the
# parent thread:
context.ensure_initialized()
ctx = context.context()
self.in_eager = ctx.executing_eagerly()
self.record_thread_local_summary_state()
self.context_device_policy = (
pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
ctx._context_handle)) # pylint: disable=protected-access
self.graph = ops.get_default_graph()
with ops.init_scope():
self._init_in_eager = context.executing_eagerly()
self._init_graph = ops.get_default_graph()
self._variable_creator_stack = self.graph._variable_creator_stack[:] # pylint: disable=protected-access
self._var_scope = variable_scope.get_variable_scope()
# Adding a "/" at end lets us re-enter this scope later.
self._name_scope = self.graph.get_name_scope()
if self._name_scope:
self._name_scope += "/"
if self.replica_id > 0:
if not self._name_scope:
self._name_scope = ""
self._name_scope += "replica_%d/" % self.replica_id
def run(self):
self.should_run.wait()
self.should_run.clear()
try:
if self.coord.should_stop():
return
self.restore_thread_local_summary_state()
# TODO(josh11b): Use current logical device instead of 0 here.
with self.coord.stop_on_exception(), \
_enter_graph(self._init_graph, self._init_in_eager), \
_enter_graph(self.graph, self.in_eager,
self._variable_creator_stack), \
context.device_policy(self.context_device_policy), \
MirroredReplicaContext(self.distribution, constant_op.constant(
self.replica_id, dtypes.int32)), \
ops.device(self.device_map.logical_to_actual_devices(0)[
self.replica_id]), \
ops.name_scope(self._name_scope), \
variable_scope.variable_scope(
self._var_scope, reuse=self.replica_id > 0), \
variable_scope.variable_creator_scope(self.variable_creator_fn):
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
self.done = True
finally:
self.has_paused.set()
def record_thread_local_summary_state(self):
"""Record the thread local summary state in self."""
# TODO(slebedev): is this still relevant? the referenced bug is closed.
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
self._summary_step = summary_state.step
self._summary_writer = summary_state.writer
self._summary_recording = summary_state.is_recording
self._summary_recording_distribution_strategy = (
summary_state.is_recording_distribution_strategy)
# TODO(b/125892694): record other fields in EagerContext.
def restore_thread_local_summary_state(self):
"""Restore thread local summary state from self."""
# TODO(slebedev): is this still relevant? the referenced bug is closed.
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
summary_state.step = self._summary_step
summary_state.writer = self._summary_writer
summary_state.is_recording = self._summary_recording
summary_state.is_recording_distribution_strategy = (
self._summary_recording_distribution_strategy)
# TODO(b/125892694): restore other fields in EagerContext.
class MirroredReplicaContext(distribute_lib.ReplicaContext):
"""ReplicaContext used in MirroredStrategy.extended.call_for_each_replica().
Opened in `_MirroredReplicaThread`, to allow the user to invoke
`MirroredStrategy`'s specific implementation of `merge_call()`,
which works by delegating the function and its arguments to
the main thread (the one that invoked
`MirroredStrategy.extended.call_for_each_replica()`).
"""
def _merge_call(self, fn, args, kwargs):
"""Delegate to the main thread to actually perform merge_call()."""
t = threading.current_thread() # a _MirroredReplicaThread
t.merge_fn = fn
t.merge_args = args
t.merge_kwargs = kwargs
t.captured_name_scope = t.graph.get_name_scope()
# Adding a "/" at end lets us re-enter this scope later.
if t.captured_name_scope:
t.captured_name_scope += "/"
t.captured_var_scope = variable_scope.get_variable_scope()
t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access
# NOTE(priyag): Throw an error if there is a merge call in the middle of a
# `fn` passed to call_for_each_replica which changes the graph being used
# while calling `fn`. This can happen when the `fn` is decorated with
# `tf.function` and there is a merge_call in `fn`. This breaks because each
# thread tries to create a distinct tf.function. Each tf.function creation
# takes a lock, and so if there is a merge call in the middle, the lock is
# never released and subsequent replica threads cannot proceed to define
# their own functions. Checking for the graph being the same is one way for
# us to check this didn't happen.
if ops.get_default_graph() != t.graph:
raise RuntimeError(
"`merge_call` called while defining a new graph or a tf.function. "
"This can often happen if the function `fn` passed to "
"`strategy.experimental_run()` is decorated with "
"`@tf.function` (or contains a nested `@tf.function`), and `fn` "
"contains a synchronization point, such as aggregating gradients. "
"This behavior is not yet supported. Instead, please wrap the entire "
"call `strategy.experimental_run(fn)` in a `@tf.function`, and avoid "
"nested `tf.function`s that may potentially cross a synchronization "
"boundary.")
t.has_paused.set()
t.should_run.wait()
t.should_run.clear()
if t.coord.should_stop():
raise _RequestedStop()
return t.merge_result
@property
def devices(self):
distribute_lib.require_replica_context(self)
replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
return [self._strategy.extended.worker_devices_by_replica[replica_id]]