# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""Library for running a computation across multiple devices.
The intent of this library is that you can write an algorithm in a stylized way
and it will be usable with a variety of different `tf.distribute.Strategy`
implementations. Each descendant will implement a different strategy for
distributing the algorithm across multiple devices/machines. Furthermore, these
changes can be hidden inside the specific layers and other library classes that
need special treatment to run in a distributed setting, so that most users'
model definition code can run unchanged. The `tf.distribute.Strategy` API works
the same way with eager and graph execution.
*Guides*
* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)
*Tutorials*
* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/)
The tutorials cover how to use `tf.distribute.Strategy` to do distributed
training with native Keras APIs, custom training loops,
and Estimator APIs. They also cover how to save/load a model when using
`tf.distribute.Strategy`.
*Glossary*
* _Data parallelism_ is where we run multiple copies of the model
on different slices of the input data. This is in contrast to
_model parallelism_ where we divide up a single copy of a model
across multiple devices.
Note: we only support data parallelism for now, but
hope to add support for model parallelism in the future.
* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
devices on a single machine, or be connected to devices on multiple
machines. Devices used to run computations are called _worker devices_.
Devices used to store variables are _parameter devices_. For some strategies,
such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
will be the same (see mirrored variables below). For others they will be
different. For example, `tf.distribute.experimental.CentralStorageStrategy`
puts the variables on a single device (which may be a worker device or may be
the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
variables on separate machines called _parameter servers_ (see below).
* A _replica_ is one copy of the model, running on one slice of the
input data. Right now each replica is executed on its own
worker device, but once we add support for model parallelism
a replica may span multiple worker devices.
* A _host_ is the CPU device on a machine with worker devices, typically
used for running input pipelines.
* A _worker_ is defined to be the physical machine(s) containing the physical
devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
worker contains one or more replicas. Typically one worker will correspond to
one machine, but in the case
of very large models with model parallelism, one worker may span multiple
machines. We typically run one input pipeline per worker, feeding all the
replicas on that worker.
* _Synchronous_, or more commonly _sync_, training is where the updates from
each replica are aggregated together before updating the model variables. This
is in contrast to _asynchronous_, or _async_ training, where each replica
updates the model variables independently. You may also have replicas
partitioned into groups which are in sync within each group but async between
groups.
* _Parameter servers_: These are machines that hold a single copy of
parameters/variables, used by some strategies (right now just
`tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
to operate on a variable retrieve it at the beginning of a step and send an
update to be applied at the end of the step. These can in principle support
either sync or async training, but right now we only have support for async
training with parameter servers. Compare to
`tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
on a single device on the same machine (and does sync training), and
`tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
(see below).
* _Replica context_ vs. _Cross-replica context_ vs. _Update context_
A _replica context_ applies
when you execute the computation function that was called with `strategy.run`.
Conceptually, you're in replica context when executing the computation
function that is being replicated.
An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
call.
A _cross-replica context_ is entered when you enter a `strategy.scope`. This
is useful for calling `tf.distribute.Strategy` methods which operate across
the replicas (like `reduce_to()`). By default you start in a _replica context_
(the "default single _replica context_") and then some methods can switch you
back and forth.
* _Distributed value_: A distributed value is represented by the base class
`tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
to represent values on multiple devices, and it contains a map from replica id
to values. Two representative kinds of `tf.distribute.DistributedValues` are
"PerReplica" and "Mirrored" values.
"PerReplica" values exist on the worker
devices, with a different value for each replica. They are produced by
iterating through a distributed dataset returned by
`tf.distribute.Strategy.experimental_distribute_dataset` and
`tf.distribute.Strategy.distribute_datasets_from_function`. They
are also the typical result returned by
`tf.distribute.Strategy.run`.
"Mirrored" values are like "PerReplica" values, except we know that the values
on all replicas are the same. We can safely read a "Mirrored" value in a
cross-replica context by using the value on any replica.
* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
replicas, like `strategy.run(fn, args=[w])` with an
argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
`strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return
values from `fn()`, which leads to one common object if the returned values
are the same object from every replica, or a `DistributedValues` object
otherwise.
* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
multiple values into one value, like "sum" or "mean". If a strategy is doing
sync training, we will perform a reduction on the gradients to a parameter
from all replicas before applying the update. _All-reduce_ is an algorithm for
performing a reduction on values from multiple devices and making the result
available on all of those devices.
* _Mirrored variables_: These are variables that are created on multiple
devices, where we keep the variables in sync by applying the same
updates to every copy. Mirrored variables are created with
`tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
Normally they are only used in synchronous training.
* _SyncOnRead variables_
_SyncOnRead variables_ are created by
`tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
they are created on multiple devices. In replica context, each
component variable on the local replica can perform reads and writes without
synchronization with each other. When the
_SyncOnRead variable_ is read in cross-replica context, the values from
component variables are aggregated and returned.
_SyncOnRead variables_ add considerable configuration complexity to the
underlying logic, so we do not encourage users to instantiate and use
_SyncOnRead variables_ on their own. We have mainly used _SyncOnRead
variables_ for use cases such as batch norm and metrics. For performance
reasons, we often don't need to keep these statistics in sync every step and
they can be accumulated on each replica independently. The only time we want
to sync them is for reporting or checkpointing, which typically happens in
cross-replica context. _SyncOnRead variables_ are also often used by advanced
users who want to control when variable values are aggregated. For example,
users sometimes want to maintain gradients independently on each replica for a
couple of steps without aggregation.
* _Distribute-aware layers_
Layers are generally called in a replica context, except when defining a
Keras functional model. `tf.distribute.in_cross_replica_context` will let you
determine which case you are in. The `tf.distribute.get_replica_context`
function returns the default replica context outside a strategy scope,
`None` within a strategy scope (that is, in cross-replica context), and
a `tf.distribute.ReplicaContext` object inside a strategy scope and within a
`tf.distribute.Strategy.run` function. The `ReplicaContext` object has an
`all_reduce` method for aggregating across all replicas.
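The unwrapping, merging, and all-reduce behavior described in this glossary can
be modeled in plain Python. The sketch below does not use TensorFlow:
`PerReplica`, `run`, and `all_reduce` are hypothetical stand-ins for
`tf.distribute.DistributedValues`, `tf.distribute.Strategy.run`, and an
all-reduce, and a tuple position stands in for each replica's device.

```python
class PerReplica(object):
  # Hypothetical stand-in for a DistributedValues object: component i holds
  # the value for replica id i.

  def __init__(self, values):
    self.values = tuple(values)


def run(fn, arg, num_replicas):
  # Unwrap: call fn once per replica, passing that replica's component of a
  # PerReplica argument, or the same plain value to every replica.
  if isinstance(arg, PerReplica):
    if len(arg.values) != num_replicas:
      raise ValueError('expected %d components' % num_replicas)
    components = arg.values
  else:
    components = (arg,) * num_replicas
  results = [fn(c) for c in components]
  # Merge: a single object if every replica returned the same object,
  # otherwise a PerReplica holding the per-replica results.
  if all(r is results[0] for r in results):
    return results[0]
  return PerReplica(results)


def all_reduce(op, per_replica):
  # Reduce the components with 'sum' or 'mean', then make the result
  # available on every replica (here: one identical copy per replica).
  total = sum(per_replica.values)
  if op == 'mean':
    total = total / len(per_replica.values)
  elif op != 'sum':
    raise ValueError('unsupported reduction: %s' % op)
  return PerReplica([total] * len(per_replica.values))


# Three replicas, each doubling its own component of `w`.
w = PerReplica([1.0, 2.0, 3.0])
doubled = run(lambda x: x * 2, w, num_replicas=3)  # values (2.0, 4.0, 6.0)
summed = all_reduce('sum', doubled)                 # values (12.0, 12.0, 12.0)
averaged = all_reduce('mean', doubled)              # values (4.0, 4.0, 4.0)
shared = object()
merged = run(lambda _: shared, w, num_replicas=3)   # merged is shared
```

Because every replica returns the same `shared` object in the last call, the
merge step collapses the results into that single object instead of wrapping
them in a `PerReplica`.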
Note that we provide a default version of `tf.distribute.Strategy` that is
used when no other strategy is in scope and provides the same API with
reasonable default behavior.
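The difference between the _mirrored_ and _SyncOnRead_ variables from the
glossary can be illustrated with a toy model. `MirroredVar` and
`SyncOnReadVar` below are hypothetical classes, not TensorFlow APIs; plain
Python numbers stand in for the per-device copies, and the read-time
aggregation is assumed to be a sum.

```python
class MirroredVar(object):
  # ON_WRITE: every update is applied to every copy, so the copies never
  # diverge and reading any one of them is safe.

  def __init__(self, initial, num_replicas):
    self.copies = [initial] * num_replicas

  def assign_add(self, delta):
    # In sync training, delta would already be the all-reduced update.
    self.copies = [c + delta for c in self.copies]

  def read(self):
    return self.copies[0]


class SyncOnReadVar(object):
  # ON_READ: each replica updates only its local copy without
  # synchronization; the copies are aggregated (summed here, by assumption)
  # only on a cross-replica read.

  def __init__(self, initial, num_replicas):
    self.copies = [initial] * num_replicas

  def assign_add_local(self, replica_id, delta):
    self.copies[replica_id] += delta

  def read_cross_replica(self):
    return sum(self.copies)


mirrored = MirroredVar(0.0, num_replicas=2)
mirrored.assign_add(1.5)             # copies: [1.5, 1.5]

metric = SyncOnReadVar(0, num_replicas=2)
metric.assign_add_local(0, 3)        # replica 0 counts 3 examples
metric.assign_add_local(1, 4)        # replica 1 counts 4 examples
metric.assign_add_local(1, 1)        # copies: [3, 5]
total = metric.read_cross_replica()  # 8
```

Deferring the aggregation to read time is what lets statistics such as metrics
be accumulated per replica every step and combined only when they are reported
or checkpointed.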
"""
# pylint: enable=line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import enum # pylint: disable=g-bad-import-order
import functools
import threading
import weakref
import six
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context as eager_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


# ------------------------------------------------------------------------------
# Context tracking whether in a strategy.update() or .update_non_slot() call.


_update_replica_id = threading.local()


def get_update_replica_id():
  """Get the current replica id if in a `tf.distribute.Strategy.update()` call.

  Returns:
    The replica id set by the enclosing `UpdateContext`, or `None` if not in
    an update call.
  """
  try:
    return _update_replica_id.current
  except AttributeError:
    return None


class UpdateContext(object):
  """Context manager when you are in `update()` or `update_non_slot()`."""

  __slots__ = ["_replica_id", "_old_replica_id"]

  def __init__(self, replica_id):
    self._replica_id = replica_id
    self._old_replica_id = None

  def __enter__(self):
    self._old_replica_id = get_update_replica_id()
    _update_replica_id.current = self._replica_id

  def __exit__(self, exception_type, exception_value, traceback):
    del exception_type, exception_value, traceback
    _update_replica_id.current = self._old_replica_id


# ------------------------------------------------------------------------------
# Public utility functions.


@tf_export(v1=["distribute.get_loss_reduction"])
def get_loss_reduction():
  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.

  This is used to decide whether the loss should be scaled in the optimizer
  (used only for the Estimator + v1 optimizer use case).

  Returns:
    `tf.distribute.ReduceOp` corresponding to the last loss reduction for the
    Estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM`
    otherwise.
  """
  if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator:  # pylint: disable=protected-access
    # If we are not in Estimator context then return 'SUM'. We do not need to
    # scale loss in the optimizer.
    return reduce_util.ReduceOp.SUM
  last_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
  if (last_reduction == losses_impl.Reduction.SUM or
      last_reduction == "sum"):  # Check for tf.keras.losses.Reduction.SUM
    return reduce_util.ReduceOp.SUM
  return reduce_util.ReduceOp.MEAN


# ------------------------------------------------------------------------------
# Internal API for validating the current thread mode


def _require_cross_replica_or_default_context_extended(extended,
                                                       error_message=None):
  """Verify in cross-replica context."""
  context = _get_per_thread_mode()
  cross_replica = context.cross_replica_context
  if cross_replica is not None and cross_replica.extended is extended:
    return
  if context is _get_default_replica_mode():
    return
  strategy = extended._container_strategy()  # pylint: disable=protected-access
  # We have an error to report, figure out the right message.
  if context.strategy is not strategy:
    _wrong_strategy_scope(strategy, context)
  assert cross_replica is None
  if not error_message:
    error_message = ("Method requires being in cross-replica context, use "
                     "get_replica_context().merge_call()")
  raise RuntimeError(error_message)


def _wrong_strategy_scope(strategy, context):
  # Figure out the right error message.
  if not distribution_strategy_context.has_strategy():
    raise RuntimeError(
        'Need to be inside "with strategy.scope()" for %s' %
        (strategy,))
  else:
    raise RuntimeError(
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (context.strategy, strategy))


def require_replica_context(replica_ctx):
  """Verify in `replica_ctx` replica context."""
  context = _get_per_thread_mode()
  if context.replica_context is replica_ctx: return
  # We have an error to report, figure out the right message.
  if context.replica_context is None:
    raise RuntimeError("Need to be inside `call_for_each_replica()`")
  if context.strategy is replica_ctx.strategy:
    # Two different ReplicaContexts with the same tf.distribute.Strategy.
    raise RuntimeError("Mismatching ReplicaContext.")
  raise RuntimeError(
      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
      (context.strategy, replica_ctx.strategy))


def _require_strategy_scope_strategy(strategy):
  """Verify in a `strategy.scope()` in this thread."""
  context = _get_per_thread_mode()
  if context.strategy is strategy: return
  _wrong_strategy_scope(strategy, context)


def _require_strategy_scope_extended(extended):
  """Verify in a `distribution_strategy.scope()` in this thread."""
  context = _get_per_thread_mode()
  if context.strategy.extended is extended: return
  # Report error.
  strategy = extended._container_strategy()  # pylint: disable=protected-access
  _wrong_strategy_scope(strategy, context)


# ------------------------------------------------------------------------------
# Internal context managers used to implement the DistributionStrategy
# base class


class _CurrentDistributionContext(object):
  """Context manager setting the current `tf.distribute.Strategy`.

  Also: overrides the variable creator and optionally the current device.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               default_device=None):
    self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
        strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    if default_device:
      self._device_scope = ops.device(default_device)
    else:
      self._device_scope = None
    self._same_scope_again_count = 0

  def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._device_scope:
        self._device_scope.__enter__()
    return self._context.strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return
    if self._device_scope:
      try:
        self._device_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)

    try:
      self._var_creator_scope.__exit__(
          exception_type, exception_value, traceback)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of `with` scope."),
          e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)
    _pop_per_thread_mode()


# TODO(yuefengz): add more replication modes.
@tf_export("distribute.InputReplicationMode")
class InputReplicationMode(enum.Enum):
  """Replication mode for input function.

  * `PER_WORKER`: The input function will be called on each worker
    independently, creating as many input pipelines as there are workers.
    Replicas will dequeue from the local Dataset on their worker.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
  * `PER_REPLICA`: The input function will be called on each replica
    separately. `tf.distribute.Strategy` doesn't manage any state sharing
    between such separate input pipelines.
  """
  PER_WORKER = "PER_WORKER"
  PER_REPLICA = "PER_REPLICA"


@tf_export("distribute.InputContext")
class InputContext(object):
  """A class wrapping information needed by an input function.

  This is a context class that is passed to the user's input function and
  contains information about the compute replicas and input pipelines. The
  number of compute replicas (in sync training) helps compute the local batch
  size from the desired global batch size for each replica. The input pipeline
  information can be used to return a different subset of the input in each
  replica (e.g. shard the input pipeline, use a different input
  source, etc.).
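As an illustration of how an input function might use these fields, the
plain-Python sketch below shards a list of records by input pipeline and
batches each shard with the per-replica batch size. `per_replica_batch_size`
and `shard_and_batch` are hypothetical helpers, not TensorFlow APIs; with
`tf.data`, the same steps would typically be
`dataset.shard(num_input_pipelines, input_pipeline_id)` combined with
`.batch(per_replica_batch_size)`.

```python
def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
  # Same arithmetic and divisibility check as
  # InputContext.get_per_replica_batch_size.
  if global_batch_size % num_replicas_in_sync != 0:
    raise ValueError('global batch size %d is not divisible by %d' %
                     (global_batch_size, num_replicas_in_sync))
  return global_batch_size // num_replicas_in_sync


def shard_and_batch(records, num_input_pipelines, input_pipeline_id,
                    num_replicas_in_sync, global_batch_size):
  # Keep every num_input_pipelines-th record starting at this pipeline's id,
  # matching the element selection of tf.data.Dataset.shard.
  shard = records[input_pipeline_id::num_input_pipelines]
  batch_size = per_replica_batch_size(global_batch_size, num_replicas_in_sync)
  # Split the shard into consecutive per-replica batches; the last batch may
  # be smaller when the shard size is not a multiple of batch_size.
  return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


# Pipeline 1 of 2, feeding 4 in-sync replicas with a global batch size of 8:
# records 1, 3, 5, 7, 9 are kept and batched 2 at a time.
batches = shard_and_batch(list(range(10)), num_input_pipelines=2,
                          input_pipeline_id=1, num_replicas_in_sync=4,
                          global_batch_size=8)
# batches == [[1, 3], [5, 7], [9]]
```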
"""

  __slots__ = [
      "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
  ]

  def __init__(self,
               num_input_pipelines=1,
               input_pipeline_id=0,
               num_replicas_in_sync=1):
    """Initializes an InputContext object.

    Args:
      num_input_pipelines: the number of input pipelines in a cluster.
      input_pipeline_id: the current input pipeline id, should be an int in
        [0,`num_input_pipelines`).
      num_replicas_in_sync: the number of replicas that are in sync.
    """
    self._num_input_pipelines = num_input_pipelines
    self._input_pipeline_id = input_pipeline_id
    self._num_replicas_in_sync = num_replicas_in_sync

  @property
  def num_replicas_in_sync(self):
    """Returns the number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def input_pipeline_id(self):
    """Returns the input pipeline ID."""
    return self._input_pipeline_id

  @property
  def num_input_pipelines(self):
    """Returns the number of input pipelines."""
    return self._num_input_pipelines

  def get_per_replica_batch_size(self, global_batch_size):
    """Returns the per-replica batch size.

    Args:
      global_batch_size: the global batch size which should be divisible by
        `num_replicas_in_sync`.

    Returns:
      the per-replica batch size.

    Raises:
      ValueError: if `global_batch_size` is not divisible by
        `num_replicas_in_sync`.
    """
    if global_batch_size % self._num_replicas_in_sync != 0:
      raise ValueError("The `global_batch_size` %r is not divisible by "
                       "`num_replicas_in_sync` %r " %
                       (global_batch_size, self._num_replicas_in_sync))
    return global_batch_size // self._num_replicas_in_sync

  def __str__(self):
    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
        self.input_pipeline_id, self.num_input_pipelines)


@tf_export("distribute.experimental.ValueContext", v1=[])
class ValueContext(object):
  """A class wrapping information needed by a distribute function.

  This is a context class that is passed to the `value_fn` in
  `strategy.experimental_distribute_values_from_function` and contains
  information about the compute replicas. The `num_replicas_in_sync` and
  `replica_id_in_sync_group` can be used to customize the value on each
  replica.

  Example usage:

  1.  Directly constructed.

      >>> def value_fn(context):
      ...   return context.replica_id_in_sync_group/context.num_replicas_in_sync
      >>> context = tf.distribute.experimental.ValueContext(
      ...   replica_id_in_sync_group=2, num_replicas_in_sync=4)
      >>> per_replica_value = value_fn(context)
      >>> per_replica_value
      0.5

  2.  Passed in by `experimental_distribute_values_from_function`.

      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
      >>> def value_fn(value_context):
      ...   return value_context.num_replicas_in_sync
      >>> distributed_values = (
      ...      strategy.experimental_distribute_values_from_function(
      ...        value_fn))
      >>> local_result = strategy.experimental_local_results(distributed_values)
      >>> local_result
      (2, 2)
  """

  __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]

  def __init__(self,
               replica_id_in_sync_group=0,
               num_replicas_in_sync=1):
    """Initializes a ValueContext object.

    Args:
      replica_id_in_sync_group: the current replica_id, should be an int in
        [0,`num_replicas_in_sync`).
      num_replicas_in_sync: the number of replicas that are in sync.
    """
    self._replica_id_in_sync_group = replica_id_in_sync_group
    self._num_replicas_in_sync = num_replicas_in_sync

  @property
  def num_replicas_in_sync(self):
    """Returns the number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def replica_id_in_sync_group(self):
    """Returns the replica ID."""
    return self._replica_id_in_sync_group

  def __str__(self):
    return (("tf.distribute.ValueContext(replica id {}, "
             "total replicas in sync: {})")
            .format(self.replica_id_in_sync_group, self.num_replicas_in_sync))


@tf_export("distribute.RunOptions")
class RunOptions(
    collections.namedtuple("RunOptions", [
        "experimental_enable_dynamic_batch_size",
        "experimental_bucketizing_dynamic_shape",
        "experimental_xla_options",
    ])):
  """Run options for `strategy.run`.

  This can be used to hold some strategy specific configs.

  Attributes:
    experimental_enable_dynamic_batch_size: Boolean. Only applies to
      TPUStrategy. Defaults to True. If True, TPUStrategy will enable the
      dynamic padder to support dynamic batch sizes for the inputs. Otherwise
      only static shape inputs are allowed.
    experimental_bucketizing_dynamic_shape: Boolean. Only applies to
      TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
      bucketize inputs passed into `run` if the input shape is
      dynamic. This is a performance optimization to reduce XLA recompilation,
      which should not affect correctness.
    experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
      TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to
      None.
  """

  def __new__(cls,
              experimental_enable_dynamic_batch_size=True,
              experimental_bucketizing_dynamic_shape=False,
              experimental_xla_options=None):
    return super(RunOptions,
                 cls).__new__(cls, experimental_enable_dynamic_batch_size,
                              experimental_bucketizing_dynamic_shape,
                              experimental_xla_options)


@tf_export("distribute.InputOptions", v1=[])
class InputOptions(
    collections.namedtuple("InputOptions", [
        "experimental_fetch_to_device",
        "experimental_replication_mode",
        "experimental_place_dataset_on_device",
        "experimental_per_replica_buffer_size",
    ])):
  """Run options for `experimental_distribute_dataset(s_from_function)`.

  This can be used to hold some strategy specific configs.

  ```python
  # Setup TPUStrategy
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)

  dataset = tf.data.Dataset.range(16)
  # Keep the prefetched elements in host memory rather than device memory.
  distributed_dataset_on_host = (
      strategy.experimental_distribute_dataset(
          dataset,
          tf.distribute.InputOptions(
              experimental_fetch_to_device=False,
              experimental_replication_mode=
              tf.distribute.InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1)))
  ```

  Attributes:
    experimental_fetch_to_device: Boolean. If True, dataset
      elements will be prefetched to accelerator device memory. When False,
      dataset elements are prefetched to host device memory. Must be False when
      using TPUEmbedding API. experimental_fetch_to_device can only be used
      with experimental_replication_mode=PER_WORKER. Default behavior is the
      same as setting it to True.
    experimental_replication_mode: Replication mode for the input function.
      Currently, `InputReplicationMode.PER_REPLICA` is only supported with
      `tf.distribute.MirroredStrategy.experimental_distribute_datasets_from_function`.
      The default value is `InputReplicationMode.PER_WORKER`.
    experimental_place_dataset_on_device: Boolean. Defaults to False. When
      True, the dataset will be placed on the device, otherwise it will remain
      on the host. experimental_place_dataset_on_device=True can only be used
      with experimental_replication_mode=PER_REPLICA.
    experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the
      prefetch buffer size in the replica device memory. Users can set it
      to 0 to completely disable prefetching behavior, or a number greater than
      1 to enable a larger buffer size. Note that this option is still
      valid with `experimental_fetch_to_device=False`.
  """

  def __new__(cls,
              experimental_fetch_to_device=None,
              experimental_replication_mode=InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1):
    if experimental_fetch_to_device is None:
      experimental_fetch_to_device = True

    return super(InputOptions,
                 cls).__new__(cls, experimental_fetch_to_device,
                              experimental_replication_mode,
                              experimental_place_dataset_on_device,
                              experimental_per_replica_buffer_size)


# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.


# Base class for v1 Strategy and v2 Strategy classes. For APIs specific to
# v1/v2 Strategy, add to implementing classes of StrategyBase.
# pylint: disable=line-too-long
class StrategyBase(object):
  """A state & compute distribution policy on a list of devices.

  See [the guide](https://www.tensorflow.org/guide/distributed_training)
  for overview and examples. See `tf.distribute.StrategyExtended` and
  [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
  for a glossary of concepts mentioned on this page such as "per-replica",
  _replica_, and _reduce_.

  In short:

  * To use it with Keras `compile`/`fit`,
    [please
    read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
  * You may pass a descendant of `tf.distribute.Strategy` to
    `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
    should distribute its computation. See
    [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
  * Otherwise, use `tf.distribute.Strategy.scope` to specify that a
    strategy should be used when building and executing your model.
    (This puts you in the "cross-replica context" for this strategy, which
    means the strategy is put in control of things like variable placement.)
  * If you are writing a custom training loop, you will need to call a few more
    methods,
    [see the
    guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):

    * Start by creating a `tf.data.Dataset` normally.
    * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
      a `tf.data.Dataset` to something that produces "per-replica" values.
      If you want to manually specify how the dataset should be partitioned
      across replicas, use
      `tf.distribute.Strategy.distribute_datasets_from_function`
      instead.
    * Use `tf.distribute.Strategy.run` to run a function
      once per replica, taking values that may be "per-replica" (e.g.
      from a `tf.distribute.DistributedDataset` object) and returning
      "per-replica" values.
      This function is executed in "replica context", which means each
      operation is performed separately on each replica.
    * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
      convert the resulting "per-replica" values into ordinary `Tensor`s.

  A custom training loop can be as simple as:

  ```
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  This takes an ordinary `dataset` and `replica_fn` and runs it
  distributed using a particular `tf.distribute.Strategy` named
  `my_strategy` above. Any variables created in `replica_fn` are created
  using `my_strategy`'s policy, and library functions called by
  `replica_fn` can use the `get_replica_context()` API to implement
  distributed-specific behavior.

  You can use the `reduce` API to aggregate results across replicas and use
  this as a return value from one iteration over a
  `tf.distribute.DistributedDataset`. Or
  you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
  accumulate metrics across steps in a given epoch.

  See the
  [custom training loop
  tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
  for a more detailed example.

  Note: `tf.distribute.Strategy` currently does not support TensorFlow's
  partitioned variables (where a single variable is split across multiple
  devices).
  """
  # pylint: enable=line-too-long

  # TODO(josh11b): Partitioned computations, state; sharding
  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling

  def __init__(self, extended):
    self._extended = extended

    # Flag that is used to indicate whether distribution strategy is used with
    # Estimator. This is required for backward compatibility of loss scaling
    # when using v1 optimizer with estimator.
    self._scale_loss_for_estimator = False

    if not hasattr(extended, "_retrace_functions_for_each_device"):
      # pylint: disable=protected-access
      # `extended._retrace_functions_for_each_device` dictates
      # whether the same function will be retraced when it is called on
      # different devices.
      try:
        extended._retrace_functions_for_each_device = (
            len(extended.worker_devices) > 1)
        distribution_strategy_replica_gauge.get_cell("num_replicas").set(
            self.num_replicas_in_sync)
      except:  # pylint: disable=bare-except
        # Default for the case where extended.worker_devices can't return
        # a sensible value.
        extended._retrace_functions_for_each_device = True

    # Below are the dicts of axis(int) -> `tf.function`.
    self._mean_reduce_helper_fns = {}
    self._reduce_sum_fns = {}

    # Whether this strategy is designed to work with `ClusterCoordinator`.
    self._should_use_with_coordinator = False

  @property
  def extended(self):
    """`tf.distribute.StrategyExtended` with additional methods."""
    return self._extended

  @tf_contextlib.contextmanager
  def _scale_loss_for_estimator_enabled(self):
    """Scope which sets a flag used for scaling losses in optimizer.

    Yields:
      `_scale_loss_for_estimator_enabled` is a context manager with a
      side effect, but doesn't return a value.
    """
    self._scale_loss_for_estimator = True
    try:
      yield
    finally:
      self._scale_loss_for_estimator = False
# pylint: disable=line-too-long
def scope(self):
"""Context manager to make the strategy current and distribute variables.
This method returns a context manager, and is used as follows:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> # Variable created inside scope:
>>> with strategy.scope():
... mirrored_variable = tf.Variable(1.)
>>> mirrored_variable
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
>>> # Variable created outside scope:
>>> regular_variable = tf.Variable(1.)
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
_What happens when Strategy.scope is entered?_
* `strategy` is installed in the global context as the "current" strategy.
Inside this scope, `tf.distribute.get_strategy()` will now return this
strategy. Outside this scope, it returns the default no-op strategy.
* Entering the scope also enters the "cross-replica context". See
`tf.distribute.StrategyExtended` for an explanation on cross-replica and
replica contexts.
* Variable creation inside `scope` is intercepted by the strategy. Each
strategy defines how it wants to affect the variable creation. Sync
strategies like `MirroredStrategy`, `TPUStrategy` and
`MultiWorkerMirroredStrategy` create variables replicated on each replica,
whereas `ParameterServerStrategy` creates variables on the parameter
servers. This is done using a custom `tf.variable_creator_scope`.
* In some strategies, a default device scope may also be entered: in
`MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
entered on each worker.
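The first bullet can be pictured with a small, framework-free sketch. It is
not how `tf.distribute` is implemented: `ToyStrategy`, `_ToyContext`, and
`toy_get_strategy` are hypothetical names, and a class attribute stands in for
the global context. It only illustrates a "current" strategy that is installed
when the scope is entered and restored when it exits.

```python
import contextlib


class _DefaultStrategy:
  # Hypothetical stand-in for the default no-op strategy.
  name = "default"


class _ToyContext:
  # Hypothetical "global context" slot holding the current strategy.
  current = _DefaultStrategy()


class ToyStrategy:

  def __init__(self, name):
    self.name = name

  @contextlib.contextmanager
  def scope(self):
    # Install this strategy as current; restore the previous one on exit,
    # even if the body of the `with` block raises.
    previous = _ToyContext.current
    _ToyContext.current = self
    try:
      yield self
    finally:
      _ToyContext.current = previous


def toy_get_strategy():
  # Analogous to tf.distribute.get_strategy(): the strategy whose scope is
  # active, or the default strategy outside any scope.
  return _ToyContext.current


mirrored = ToyStrategy("mirrored")
with mirrored.scope():
  print(toy_get_strategy().name)  # mirrored
print(toy_get_strategy().name)  # default
```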
Note: Entering a scope does not automatically distribute a computation, except
in the case of high-level training frameworks like Keras `model.fit`. If
you're not using `model.fit`, you need to use the `strategy.run` API to
explicitly distribute that computation.
See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
_What should be in scope and what should be outside?_
There are a number of requirements on what needs to happen inside the scope.
However, in places where we have information about which strategy is in use,
we often enter the scope for the user, so they don't have to do it
explicitly (i.e. calling those either inside or outside the scope is OK).
* Anything that creates variables that should be distributed variables
must be called in a `strategy.scope`. This can be accomplished either by
directly calling the variable creating function within the scope context,
or by relying on another API like `strategy.run` or `keras.Model.fit` to
automatically enter it for you. Any variable that is created outside scope
will not be distributed and may have performance implications. Some common
objects that create variables in TF are Models, Optimizers, Metrics. Such
objects should always be initialized in the scope, and any functions
that may lazily create variables (e.g., `Model.__call__()`, tracing a
`tf.function`, etc.) should similarly be called within scope. Another
source of variable creation can be a checkpoint restore, when variables
are created lazily. Note that any variable created inside a strategy's
scope captures the strategy information, so reading and writing such
variables outside the `strategy.scope` also works seamlessly, without
the user having to enter the scope.
* Some strategy APIs (such as `strategy.run` and `strategy.reduce`) that
require being in a strategy's scope enter the scope automatically, so you
don't need to enter the scope explicitly when using them.
* When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
object captures the scope information. When high level training framework
methods such as `model.compile`, `model.fit`, etc. are then called, the
captured scope will be automatically entered, and the associated strategy
will be used to distribute the training etc. See a detailed example in
[distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
WARNING: Simply calling `model(..)` does not automatically enter the
captured scope; only high-level training framework APIs support this
behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
and `model.save` can all be called inside or outside the scope.
* The following can be either inside or outside the scope:
* Creating the input datasets
* Defining `tf.function`s that represent your training step
* Saving APIs such as `tf.saved_model.save`. Loading creates variables,
so that should go inside the scope if you want to train the model in a
distributed way.
* Checkpoint saving. As mentioned above - `checkpoint.restore` may
sometimes need to be inside scope if it creates variables.
Returns:
A context manager.
"""
return self._extended._scope(self) # pylint: disable=protected-access
# pylint: enable=line-too-long
@doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended`
def colocate_vars_with(self, colocate_with_variable):
"""DEPRECATED: use extended.colocate_vars_with() instead."""
return self._extended.colocate_vars_with(colocate_with_variable)
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
def make_dataset_iterator(self, dataset):
"""DEPRECATED TF 1.x ONLY."""
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
def make_input_fn_iterator(self,
input_fn,
replication_mode=InputReplicationMode.PER_WORKER):
"""DEPRECATED TF 1.x ONLY."""
if replication_mode != InputReplicationMode.PER_WORKER:
raise ValueError(
"Input replication mode not supported: %r" % replication_mode)
with self.scope():
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
input_fn, replication_mode=replication_mode)
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
def experimental_run(self, fn, input_iterator=None):
"""DEPRECATED TF 1.x ONLY."""
with self.scope():
args = (input_iterator.get_next(),) if input_iterator is not None else ()
return self.run(fn, args=args)
def experimental_distribute_dataset(self, dataset, options=None):
# pylint: disable=line-too-long
"""Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
The returned `tf.distribute.DistributedDataset` can be iterated over
similar to regular datasets.
NOTE: The user cannot add any more transformations to a
`tf.distribute.DistributedDataset`. You can only create an iterator or
examine the `tf.TypeSpec` of the data generated by it. See API docs of
`tf.distribute.DistributedDataset` to learn more.
The following is an example:
>>> global_batch_size = 2
>>> # Passing the devices is optional.
... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
>>> # Create a dataset
... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
>>> # Distribute that dataset
... dist_dataset = strategy.experimental_distribute_dataset(dataset)
>>> @tf.function
... def replica_fn(input):
... return input*2
>>> result = []
>>> # Iterate over the `tf.distribute.DistributedDataset`
... for x in dist_dataset:
... # process dataset elements
... result.append(strategy.run(replica_fn, args=(x,)))
>>> print(result)
[PerReplica:{
0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
}, PerReplica:{
0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
}]
Three key actions happening under the hood of this method are batching,
sharding, and prefetching.
In the code snippet above, `dataset` is batched by `global_batch_size`, and
calling `experimental_distribute_dataset` on it rebatches `dataset` to a
new batch size that is equal to the global batch size divided by the number
of replicas in sync. We iterate through it using a Pythonic for loop.
`x` is a `tf.distribute.DistributedValues` containing data for all replicas,
and each replica gets data of the new batch size.
`tf.distribute.Strategy.run` will take care of feeding the right per-replica
data in `x` to the right `replica_fn` executed on each replica.
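The rebatching arithmetic can be sketched in plain Python. This is an
illustration only. It assumes the global batch size divides evenly by the
number of replicas, and `split_global_batch` is a hypothetical helper that
works on Python lists, whereas the real rebatching happens inside the
`tf.data` pipeline.

```python
def split_global_batch(global_batch, num_replicas):
  # Each replica receives a contiguous slice of size
  # len(global_batch) // num_replicas, in replica order.
  if len(global_batch) % num_replicas:
    raise ValueError("global batch size must divide evenly across replicas")
  per_replica = len(global_batch) // num_replicas
  return [global_batch[i * per_replica:(i + 1) * per_replica]
          for i in range(num_replicas)]


# Mirrors the example above: range(4) batched by 2, split across 2 replicas.
global_batches = [[0, 1], [2, 3]]
per_replica_batches = [split_global_batch(b, 2) for b in global_batches]
print(per_replica_batches)  # [[[0], [1]], [[2], [3]]]
```

Doubling each per-replica batch, as `replica_fn` does, gives `[0]`, `[2]` and
then `[4]`, `[6]`, which matches the `PerReplica` output shown above.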
Sharding consists of autosharding across multiple workers and sharding
within every worker. First, in multi-worker distributed training (i.e. when you use
`tf.distribute.experimental.MultiWorkerMirroredStrategy`
or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of
workers means that each worker is assigned a subset of the entire dataset
(if the right `tf.data.experimental.AutoShardPolicy` is set). This is to
ensure that at each step, a global batch size of non-overlapping dataset
elements will be processed by each worker. Autosharding has a couple of
different options that can be specified using
`tf.data.experimental.DistributeOptions`. Then, sharding within each worker
means the method will split the data among all the worker devices (if more
than one is present). This happens regardless of multi-worker
autosharding.
Note: for autosharding across multiple workers, the default mode is
`tf.data.experimental.AutoShardPolicy.AUTO`. This mode
will attempt to shard the input dataset by files if the dataset is
being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
`tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
where each of the workers will read the entire dataset and only process the
shard assigned to it. However, if you have fewer than one input file per
worker, we suggest that you disable dataset autosharding across workers by
setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be
`tf.data.experimental.AutoShardPolicy.OFF`.
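As a rough, framework-free illustration of the two behaviors described in the
note, the sketch below shards a list of file names (file-based sharding) and a
list of elements (data-based sharding) with the same hypothetical `take_shard`
helper. The helper is similar in spirit to
`tf.data.Dataset.shard(num_shards, index)`. Whether `tf.data` assigns files to
workers in exactly this order is an assumption. The sketch also shows why too
few files is a problem: under file-based sharding, a worker can end up with
nothing to read.

```python
def take_shard(items, num_shards, index):
  # Keep the items at positions index, index + num_shards, ...
  return items[index::num_shards]


num_workers = 4
files = ["a.tfrecord", "b.tfrecord", "c.tfrecord"]

# File-based sharding: each worker reads a disjoint subset of the files.
# With 3 files and 4 workers, the last worker gets no files at all.
print([take_shard(files, num_workers, w) for w in range(num_workers)])
# [['a.tfrecord'], ['b.tfrecord'], ['c.tfrecord'], []]

# Data-based sharding: every worker reads all elements but keeps only its
# own positions, so work is split evenly at the cost of redundant reads.
elements = list(range(8))
print([take_shard(elements, num_workers, w) for w in range(num_workers)])
# [[0, 4], [1, 5], [2, 6], [3, 7]]
```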
By default, this method adds a prefetch transformation at the end of the
user-provided `tf.data.Dataset` instance. The prefetch transformation's
`buffer_size` argument is equal to the number of replicas in sync.
If the above batch splitting and dataset sharding logic is undesirable,
please use
`tf.distribute.Strategy.distribute_datasets_from_function`
instead, which does not do any automatic batching or sharding for you.
Note: If you are using TPUStrategy, the order in which the data is processed
by the workers when using
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.distribute_datasets_from_function` is
not guaranteed. Ordering is typically required if you are using
`tf.distribute` to scale prediction. You can, however, insert an index for
each element in the batch and order outputs accordingly. Refer to [this
snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
for an example of how to order outputs.
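The index-and-reorder workaround can be sketched independently of any
strategy. The sketch assumes each element was tagged with its position before
distribution (for example with `tf.data.Dataset.enumerate`), and that the
per-replica outputs have already been gathered as `(index, prediction)`
pairs. `reorder_outputs` is a hypothetical helper.

```python
def reorder_outputs(indexed_outputs):
  # Sort the gathered (index, prediction) pairs by original position and
  # drop the index, restoring the input order.
  return [pred for _, pred in sorted(indexed_outputs, key=lambda p: p[0])]


# Outputs arriving in a worker-dependent order.
gathered = [(2, "c"), (0, "a"), (3, "d"), (1, "b")]
print(reorder_outputs(gathered))  # ['a', 'b', 'c', 'd']
```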
Note: Stateful dataset transformations are currently not supported with
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.distribute_datasets_from_function`. Any stateful
ops that the dataset may have are currently ignored. For example, if your
dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
then you have a dataset graph that depends on state (i.e. the random seed)
on the local machine where the Python process is being executed.
For a tutorial on more usage and properties of this method, refer to the
[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
Args:
dataset: `tf.data.Dataset` that will be sharded across all replicas using
the rules stated above.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A `tf.distribute.DistributedDataset`.
"""
# pylint: enable=line-too-long
return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access
def distribute_datasets_from_function(self, dataset_fn, options=None):
# pylint: disable=line-too-long
"""Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
The argument `dataset_fn` that users pass in is an input function that has a
`tf.distribute.InputContext` argument and returns a `tf.data.Dataset`
instance. It is expected that the returned dataset from `dataset_fn` is
already batched by per-replica batch size (i.e. global batch size divided by
the number of replicas in sync) and sharded.
`tf.distribute.Strategy.distribute_datasets_from_function` does
not batch or shard the `tf.data.Dataset` instance
returned from the input function. `dataset_fn` will be called on the CPU
device of each of the workers and each generates a dataset where every
replica on that worker will dequeue one batch of inputs (i.e. if a worker
has two replicas, two batches will be dequeued from the `Dataset` every
step).
This method can be used for several purposes. First, it allows you to
specify your own batching and sharding logic. (In contrast,
`tf.distribute.Strategy.experimental_distribute_dataset` does batching and
sharding for you.) For example, where
`experimental_distribute_dataset` is unable to shard the input files, this
method might be used to manually shard the dataset (avoiding the slow
fallback behavior in `experimental_distribute_dataset`). In cases where the
dataset is infinite, this sharding can be done by creating dataset replicas
that differ only in their random seed.
The `dataset_fn` should take a `tf.distribute.InputContext` instance through
which information about batching and input replication can be accessed.
You can use the `element_spec` property of the
`tf.distribute.DistributedDataset` returned by this API to query the
`tf.TypeSpec` of the elements returned by the iterator. This can be used to
set the `input_signature` property of a `tf.function`. Follow
`tf.distribute.DistributedDataset.element_spec` to see an example.
IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
per-replica batch size, unlike `experimental_distribute_dataset`, which uses
the global batch size. This may be computed using
`input_context.get_per_replica_batch_size`.
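The contract above (shard per input pipeline, batch by the per-replica batch
size) can be mimicked without TensorFlow. `ToyInputContext` is a hypothetical
stand-in that reuses the attribute names of `tf.distribute.InputContext`
(`num_input_pipelines`, `input_pipeline_id`, `num_replicas_in_sync`,
`get_per_replica_batch_size`). `toy_dataset_fn` plays the role of
`dataset_fn` but works on a Python list instead of a `tf.data.Dataset`.

```python
class ToyInputContext:
  # Hypothetical stand-in for tf.distribute.InputContext.

  def __init__(self, num_input_pipelines=1, input_pipeline_id=0,
               num_replicas_in_sync=1):
    self.num_input_pipelines = num_input_pipelines
    self.input_pipeline_id = input_pipeline_id
    self.num_replicas_in_sync = num_replicas_in_sync

  def get_per_replica_batch_size(self, global_batch_size):
    # Global batch size divided by the number of replicas in sync.
    if global_batch_size % self.num_replicas_in_sync:
      raise ValueError("global batch size must be divisible by the "
                       "number of replicas in sync")
    return global_batch_size // self.num_replicas_in_sync


def toy_dataset_fn(input_context, elements, global_batch_size):
  # Shard first: this input pipeline keeps every num_input_pipelines-th
  # element, starting at its own pipeline id.
  start = input_context.input_pipeline_id
  step = input_context.num_input_pipelines
  shard = elements[start::step]
  # Then batch with the per-replica batch size, not the global one.
  size = input_context.get_per_replica_batch_size(global_batch_size)
  return [shard[i:i + size] for i in range(0, len(shard), size)]


# Two workers (one input pipeline each), two replicas per worker.
ctx = ToyInputContext(num_input_pipelines=2, input_pipeline_id=0,
                      num_replicas_in_sync=4)
print(toy_dataset_fn(ctx, list(range(16)), global_batch_size=8))
# [[0, 2], [4, 6], [8, 10], [12, 14]]
```

At each step, the two replicas on this worker dequeue one batch each. Across
both workers, four replicas therefore consume 4 * 2 = 8 elements, which
matches the global batch size.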
Note: If you are using TPUStrategy, the order in which the data is processed
by the workers when using
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.distribute_datasets_from_function` is
not guaranteed. Ordering is typically required if you are using
`tf.distribute` to scale prediction. You can, however, insert an index for
each element in the batch and order outputs accordingly. Refer to [this
snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
for an example of how to order outputs.
Note: Stateful dataset transformations are currently not supported with
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.distribute_datasets_from_function`. Any stateful
ops that the dataset may have are currently ignored. For example, if your
dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
then you have a dataset graph that depends on state (i.e. the random seed)
on the local machine where the Python process is being executed.
For a tutorial on more usage and properties of this method, refer to the
[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
Args:
dataset_fn: A function taking a `tf.distribute.InputContext` instance and
returning a `tf.data.Dataset`.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A `tf.distribute.DistributedDataset`.
"""
# pylint: enable=line-too-long
return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access
dataset_fn, options)
# TODO(b/162776748): Remove deprecated symbol.
@doc_controls.do_not_doc_inheritable
@deprecation.deprecated(None, "rename to distribute_datasets_from_function")
def experimental_distribute_datasets_from_function(self,
dataset_fn,
options=None):
return self.distribute_datasets_from_function(dataset_fn, options)
def run(self, fn, args=(), kwargs=None, options=None):
"""Invokes `fn` on each replica, with the given arguments.
This method is the primary way to distribute your computation with a
tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
have `tf.distribute.DistributedValues`, such as those produced by a
`tf.distribute.DistributedDataset` from
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.distribute_datasets_from_function`,
when `fn` is executed on a particular replica, it will be executed with the
component of `tf.distribute.DistributedValues` that corresponds to that
replica.
`fn` is invoked under a replica context. `fn` may call
`tf.distribute.get_replica_context()` to access members such as
`all_reduce`. Please see the module-level docstring of tf.distribute for the
concept of replica context.
All arguments in `args` or `kwargs` can be a nested structure of tensors,
e.g. a list of tensors, in which case `args` and `kwargs` will be passed to
the `fn` invoked on each replica. Alternatively, `args` or `kwargs` can be
`tf.distribute.DistributedValues` containing tensors or composite tensors,
i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call
will get the component of a `tf.distribute.DistributedValues` corresponding
to its replica. Note that arbitrary Python values that are not of the types
above are not supported.
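The way `run` hands each replica its own component can be sketched with a toy
container. `ToyPerReplica` and `toy_run` are hypothetical. They ignore
devices, tracing, nested structures, and `kwargs`. Their only purpose is to
show that per-replica arguments are unpacked replica by replica, while an
ordinary argument (a plain number here, standing in for a tensor) is passed
unchanged to every replica.

```python
class ToyPerReplica:
  # Hypothetical stand-in for tf.distribute.DistributedValues: one
  # component per replica, indexed by replica id.

  def __init__(self, values):
    self.values = tuple(values)


def toy_run(fn, args, num_replicas):
  results = []
  for replica_id in range(num_replicas):
    # Select this replica's component from per-replica arguments and pass
    # every other argument through unchanged.
    replica_args = [a.values[replica_id] if isinstance(a, ToyPerReplica)
                    else a for a in args]
    results.append(fn(*replica_args))
  return ToyPerReplica(results)


out = toy_run(lambda x, scale: x * scale,
              args=(ToyPerReplica([1.0, 2.0]), 3.0), num_replicas=2)
print(out.values)  # (3.0, 6.0)
```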
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
whether eager execution is enabled, `fn` may be called one or more times. If
`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
called inside a `tf.function` (eager execution is disabled inside a
`tf.function` by default), `fn` is called once per replica to generate a
TensorFlow graph, which will then be reused for execution with new inputs.
Otherwise, if eager execution is enabled, `fn` will be called once per
replica every step just like regular Python code.
Example usage:
1. Constant tensor input.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> tensor_input = tf.constant(3.0)
>>> @tf.function
... def replica_fn(input):
... return input*2.0
>>> result = strategy.run(replica_fn, args=(tensor_input,))
>>> result
PerReplica:{
0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
}
2. DistributedValues input.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> @tf.function
... def run():
... def value_fn(value_context):
... return value_context.num_replicas_in_sync
... distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
... def replica_fn2(input):
... return input*2
... return strategy.run(replica_fn2, args=(distributed_values,))
>>> result = run()
>>> result
<tf.Tensor: shape=(), dtype=int32, numpy=4>
3. Use `tf.distribute.ReplicaContext` to allreduce values.
>>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
>>> @tf.function
... def run():
... def value_fn(value_context):
... return tf.constant(value_context.replica_id_in_sync_group)
... distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
... def replica_fn(input):
... return tf.distribute.get_replica_context().all_reduce("sum", input)
... return strategy.run(replica_fn, args=(distributed_values,))
>>> result = run()
>>> result
PerReplica:{
0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
Args:
fn: The function to run on each replica.
args: Optional positional arguments to `fn`. Its element can be a tensor,
a nested structure of tensors or a `tf.distribute.DistributedValues`.
kwargs: Optional keyword arguments to `fn`. Its element can be a tensor,
a nested structure of tensors or a `tf.distribute.DistributedValues`.
options: An optional instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
structure can be either a `tf.distribute.DistributedValues` or a `Tensor`
(for example, when running on a single replica).
"""
del options
if not isinstance(args, (list, tuple)):
raise ValueError(
"positional args must be a list or tuple, got {}".format(type(args)))
with self.scope():
# tf.distribute supports Eager functions, so AutoGraph should not be
# applied when the caller is also in Eager mode.
fn = autograph.tf_convert(
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
def reduce(self, reduce_op, value, axis):
"""Reduce `value` across replicas and return result on current device.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> def step_fn():
... i = tf.distribute.get_replica_context().replica_id_in_sync_group
... return tf.identity(i)
>>>
>>> per_replica_result = strategy.run(step_fn)
>>> total = strategy.reduce("SUM", per_replica_result, axis=None)
>>> total
<tf.Tensor: shape=(), dtype=int32, numpy=1>
To see which devices the per-replica and reduced results are placed on,
consider the same example with MirroredStrategy with 2 GPUs:
```python
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
def step_fn():
i = tf.distribute.get_replica_context().replica_id_in_sync_group
return tf.identity(i)
per_replica_result = strategy.run(step_fn)
# Check devices on which per replica result is:
strategy.experimental_local_results(per_replica_result)[0].device
# /job:localhost/replica:0/task:0/device:GPU:0
strategy.experimental_local_results(per_replica_result)[1].device
# /job:localhost/replica:0/task:0/device:GPU:1
total = strategy.reduce("SUM", per_replica_result, axis=None)
# Check device on which reduced result is:
total.device
# /job:localhost/replica:0/task:0/device:CPU:0
```
This API is typically used for aggregating the results returned from
different replicas, for reporting etc. For example, loss computed from
different replicas can be averaged using this API before printing.
Note: The result is copied to the "current" device, which would typically
be the CPU of the worker on which the program is running. For `TPUStrategy`,
it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`,
this is the CPU of each worker.
There are a number of different tf.distribute APIs for reducing values
across replicas:
* `tf.distribute.ReplicaContext.all_reduce`: This differs from
`Strategy.reduce` in that it is for replica context and does
not copy the results to the host device. `all_reduce` should typically be
used for reductions inside the training step such as gradients.
* `tf.distribute.StrategyExtended.reduce_to` and
`tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more
advanced versions of `Strategy.reduce` as they allow customizing the
destination of the result. They are also called in cross-replica context.
_What should axis be?_
Given a per-replica value returned by `run`, say a
per-example loss, the batch will be divided across all the replicas. This
function allows you to aggregate across replicas and optionally also across
batch elements by specifying the axis parameter accordingly.
For example, if you have a global batch size of 8 and 2
replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
`[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will
aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`.
This is useful when each replica is computing a scalar or some other value
that doesn't have a "batch" dimension (like a gradient or loss).
```
strategy.reduce("sum", per_replica_result, axis=None)
```
Sometimes, you will want to aggregate across both the global batch _and_
all replicas. You can get this behavior by specifying the batch
dimension as the `axis`, typically `axis=0`. In this case it would return a
scalar `0+1+2+3+4+5+6+7`.
```
strategy.reduce("sum", per_replica_result, axis=0)
```
If there is a last partial batch, you will need to specify an axis so
that the resulting shape is consistent across replicas. So if the last
batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
would get a shape mismatch unless you specify `axis=0`. If you specify
`tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
denominator of 6. Contrast this with computing `reduce_mean` to get a
scalar value on each replica and this function to average those means,
which will weigh some values `1/8` and others `1/4`.
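The arithmetic above can be checked with a plain-Python sketch that
represents each replica's per-example values as a list. `toy_reduce` is
hypothetical and only models a 1-D batch dimension. With `axis=None` it
combines replicas element-wise. With `axis=0` it also reduces over the batch,
and for `"mean"` it sums numerators and denominators across replicas before
dividing, which mirrors the numerator/denominator approach in the
implementation below.

```python
def toy_reduce(op, per_replica, axis):
  if axis is None:
    # Element-wise across replicas; every replica needs the same shape.
    if len({len(v) for v in per_replica}) != 1:
      raise ValueError("shape mismatch across replicas; specify axis=0")
    summed = [sum(vals) for vals in zip(*per_replica)]
    if op == "mean":
      return [s / len(per_replica) for s in summed]
    return summed
  if axis != 0:
    raise ValueError("this sketch only models axis=None or axis=0")
  # Reduce over the batch on each replica, then across replicas.
  numer = sum(sum(v) for v in per_replica)
  if op == "mean":
    denom = sum(len(v) for v in per_replica)
    return numer / denom
  return numer


full = [[0, 1, 2, 3], [4, 5, 6, 7]]
print(toy_reduce("sum", full, axis=None))  # [4, 6, 8, 10]
print(toy_reduce("sum", full, axis=0))  # 28

partial = [[0, 1, 2, 3], [4, 5]]
print(toy_reduce("mean", partial, axis=0))  # 2.5, i.e. 15 / 6
# Averaging the per-replica means instead weighs the two replicas equally:
print((sum(partial[0]) / 4 + sum(partial[1]) / 2) / 2)  # 3.0
```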
Args:
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representation of the enum such as
"SUM", "MEAN".
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
`Strategy.run`, to be combined into a single tensor. It can also be a
regular tensor when used with `OneDeviceStrategy` or default strategy.
axis: specifies the dimension to reduce along within each
replica's tensor. Should typically be set to the batch dimension, or
`None` to only reduce across replicas (e.g. if the tensor has no batch
dimension).
Returns:
A `Tensor`.
"""
# TODO(josh11b): support `value` being a nest.
_require_cross_replica_or_default_context_extended(self._extended)
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
if axis is None:
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
if reduce_op == reduce_util.ReduceOp.SUM:
def reduce_sum(v):
return math_ops.reduce_sum(v, axis=axis)
if eager_context.executing_eagerly():
# As some strategies (e.g. TPUStrategy) don't support pure eager
# execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be
# run from eager mode. Cache the tf.function by `axis` to avoid retracing
# the same function.
if axis not in self._reduce_sum_fns:
def reduce_sum_fn(v):
return self.run(reduce_sum, args=(v,))
self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn)
value = self._reduce_sum_fns[axis](value)
else:
value = self.run(reduce_sum, args=(value,))
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
if reduce_op != reduce_util.ReduceOp.MEAN:
raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
"not: %r" % reduce_op)
# TODO(josh11b): Support list/tuple and tensor axis values.
if not isinstance(axis, six.integer_types):
raise TypeError("Expected `axis` to be an integer not: %r" % axis)
def mean_reduce_helper(v, axis=axis):
"""Computes the numerator and denominator on each replica."""
numer = math_ops.reduce_sum(v, axis=axis)
if v.shape.rank is not None:
# Note(joshl): We support axis < 0 to be consistent with the
# tf.math.reduce_* operations.
if axis < 0:
if axis + v.shape.rank < 0:
raise ValueError(
"`axis` = %r out of range for `value` with rank %d" %
(axis, v.shape.rank))
axis += v.shape.rank
elif axis >= v.shape.rank:
raise ValueError(
"`axis` = %r out of range for `value` with rank %d" %
(axis, v.shape.rank))
# TF v2 returns `None` for unknown dimensions and an integer for
# known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
# or tensor_shape.Dimension(integer). `dimension_value` hides this
# difference, always returning `None` or an integer.
dim = tensor_shape.dimension_value(v.shape[axis])
if dim is not None:
# By returning a python value in the static shape case, we can
# maybe get a fast path for reducing the denominator.
# TODO(b/151871486): Remove array_ops.identity after we fallback to
# simple reduction if inputs are all on CPU.
return numer, array_ops.identity(
constant_op.constant(dim, dtype=dtypes.int64))
elif axis < 0:
axis = axis + array_ops.rank(v)
# TODO(b/151871486): Remove array_ops.identity after we fallback to simple
# reduction if inputs are all on CPU.
denom = array_ops.identity(
array_ops.shape_v2(v, out_type=dtypes.int64)[axis])
# TODO(josh11b): Should we cast denom to v.dtype here instead of after the
# reduce is complete?
return numer, denom
if eager_context.executing_eagerly():
# As some strategies (e.g. TPUStrategy) don't support pure eager
# execution, wrap the `mean_reduce_helper` with a `tf.function` so it can
# be run from eager mode. Cache the tf.function by `axis` to avoid
# retracing the same function.
if axis not in self._mean_reduce_helper_fns:
def mean_reduce_fn(v):
return self.run(mean_reduce_helper, args=(v,))
self._mean_reduce_helper_fns[axis] = def_function.function(
mean_reduce_fn)
numer, denom = self._mean_reduce_helper_fns[axis](value)
else:
numer, denom = self.run(mean_reduce_helper, args=(value,))
# TODO(josh11b): Should batch reduce here instead of doing two.
numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access
denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access
denom = math_ops.cast(denom, numer.dtype)
return math_ops.truediv(numer, denom)
@doc_controls.do_not_doc_inheritable # DEPRECATED
def unwrap(self, value):
"""Returns the list of all local per-replica values contained in `value`.
DEPRECATED: Please use `experimental_local_results` instead.
Note: This only returns values on the workers initiated by this client.
When using a `tf.distribute.Strategy` like
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
will be its own client, and this function will only return values
computed on that worker.
Args:
value: A value returned by `experimental_run()`,
`extended.call_for_each_replica()`, or a variable created in `scope`.
Returns:
A tuple of values contained in `value`. If `value` represents a single
value, this returns `(value,)`.
"""
return self._extended._local_results(value) # pylint: disable=protected-access
def experimental_local_results(self, value):
"""Returns the list of all local per-replica values contained in `value`.
Note: This only returns values on the worker initiated by this client.
When using a `tf.distribute.Strategy` like
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
will be its own client, and this function will only return values
computed on that worker.
Args:
value: A value returned by `experimental_run()`, `run()`, or a variable
created in `scope`.
Returns:
A tuple of values contained in `value` where ith element corresponds to
ith replica. If `value` represents a single value, this returns
`(value,)`.
"""
return self._extended._local_results(value) # pylint: disable=protected-access
@doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only
def group(self, value, name=None):
"""Shortcut for `tf.group(self.experimental_local_results(value))`."""
return self._extended._group(value, name) # pylint: disable=protected-access
@property
def num_replicas_in_sync(self):
"""Returns the number of replicas over which gradients are aggregated."""
return self._extended._num_replicas_in_sync # pylint: disable=protected-access
@doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string
def configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""DEPRECATED: use `update_config_proto` instead.
Configures the strategy class.
This method's functionality has been split into the strategy constructor and
`update_config_proto`. In the future, we will allow passing cluster and
config_proto to the constructor to configure the strategy, and
`update_config_proto` can be used to update the config_proto based on the
specific strategy.
"""
return self._extended._configure( # pylint: disable=protected-access
session_config, cluster_spec, task_type, task_id)
@doc_controls.do_not_generate_docs # DEPRECATED
def update_config_proto(self, config_proto):
"""DEPRECATED TF 1.x ONLY."""
return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access
def __deepcopy__(self, memo):
# First do a regular deepcopy of `self`.
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
# One little fix-up: we want `result._extended` to reference `result`
# instead of `self`.
result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access
return result
def __copy__(self):
raise RuntimeError("Must only deepcopy DistributionStrategy.")
@property
def cluster_resolver(self):
"""Returns the cluster resolver associated with this strategy.
In general, when using a multi-worker `tf.distribute` strategy such as
`tf.distribute.experimental.MultiWorkerMirroredStrategy` or
`tf.distribute.TPUStrategy()`, there is a
`tf.distribute.cluster_resolver.ClusterResolver` associated with the
strategy used, and such an instance is returned by this property.
Strategies that intend to have an associated
`tf.distribute.cluster_resolver.ClusterResolver` must set the
relevant attribute, or override this property; otherwise, `None` is returned
by default. Those strategies should also provide information regarding what
is returned by this property.
Single-worker strategies usually do not have a
`tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
property will return `None`.
The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
user needs to access information such as the cluster spec, task type or task
id. For example,
```python
import json
import os
import tensorflow as tf
os.environ['TF_CONFIG'] = json.dumps({
  'cluster': {
    'worker': ["localhost:12345", "localhost:23456"],
    'ps': ["localhost:34567"]
  },
  'task': {'type': 'worker', 'index': 0}
})
# This implicitly uses TF_CONFIG for the cluster and current task info.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
...
if strategy.cluster_resolver.task_type == 'worker':
  # Perform something that's only applicable on workers. Since we set this
  # as a worker above, this block will run on this particular instance.
  pass
elif strategy.cluster_resolver.task_type == 'ps':
  # Perform something that's only applicable on parameter servers. Since we
  # set this as a worker above, this block will not run on this particular
  # instance.
  pass
```
For more information, please see
`tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
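If you only need the task type and index before a strategy exists, one option is to read `TF_CONFIG` directly with the standard library. `task_from_tf_config` below is a hypothetical helper, not part of TensorFlow, and it assumes `TF_CONFIG` uses the `cluster`/`task` layout shown in the example above:
```python
import json
import os
def task_from_tf_config(default_type=None, default_index=0):
  # Hypothetical helper: returns (task_type, task_index) from TF_CONFIG,
  # falling back to the defaults when the variable or its keys are absent.
  config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  task = config.get("task", {})
  return task.get("type", default_type), task.get("index", default_index)
os.environ["TF_CONFIG"] = json.dumps({
  "cluster": {"worker": ["localhost:12345", "localhost:23456"],
              "ps": ["localhost:34567"]},
  "task": {"type": "worker", "index": 0},
})
task_type, task_index = task_from_tf_config()
print(task_type, task_index)  # worker 0
```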
Returns:
The cluster resolver associated with this strategy. Returns `None` if a
cluster resolver is not applicable or available in this strategy.
"""
if hasattr(self.extended, "_cluster_resolver"):
return self.extended._cluster_resolver # pylint: disable=protected-access
return None
@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring
class Strategy(StrategyBase):
__doc__ = StrategyBase.__doc__
def experimental_distribute_values_from_function(self, value_fn):
"""Generates `tf.distribute.DistributedValues` from `value_fn`.
This function generates `tf.distribute.DistributedValues` to pass
into `run`, `reduce`, or other methods that take
distributed values when not using datasets.
Args:
value_fn: The function to run to generate values. It is called for
each replica with `tf.distribute.ValueContext` as the sole argument. It
must return a Tensor or a type that can be converted to a Tensor.
Returns:
A `tf.distribute.DistributedValues` containing a value for each replica.
Example usage:
1. Return constant value per replica:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> def value_fn(ctx):
... return tf.constant(1.)
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
2. Distribute values in array based on replica_id:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> array_value = np.array([3., 2., 1.])
>>> def value_fn(ctx):
... return array_value[ctx.replica_id_in_sync_group]
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(3.0, 2.0)
3. Specify values using num_replicas_in_sync:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> def value_fn(ctx):
... return ctx.num_replicas_in_sync
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(2, 2)
4. Place values on devices and distribute:
```
strategy = tf.distribute.TPUStrategy()
worker_devices = strategy.extended.worker_devices
multiple_values = []
for i in range(strategy.num_replicas_in_sync):
with tf.device(worker_devices[i]):
multiple_values.append(tf.constant(1.0))
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = strategy.experimental_distribute_values_from_function(
    value_fn)
```
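Conceptually, `value_fn` is called once per replica, in replica-id order, with a context describing that replica, and the results are then wrapped in a `tf.distribute.DistributedValues`. The plain-Python sketch below models only that calling convention. `FakeValueContext` and `distribute_values_from_function` are hypothetical names, and a tuple stands in for the distributed result:
```python
from dataclasses import dataclass
@dataclass(frozen=True)
class FakeValueContext:
  # Hypothetical stand-in for tf.distribute.ValueContext, keeping only the
  # two fields used in the examples above.
  replica_id_in_sync_group: int
  num_replicas_in_sync: int
def distribute_values_from_function(value_fn, num_replicas):
  # Call `value_fn` once per replica and collect the results in order.
  return tuple(
    value_fn(FakeValueContext(replica_id, num_replicas))
    for replica_id in range(num_replicas))
array_value = [3., 2., 1.]
print(distribute_values_from_function(
  lambda ctx: array_value[ctx.replica_id_in_sync_group], 2))  # (3.0, 2.0)
print(distribute_values_from_function(
  lambda ctx: ctx.num_replicas_in_sync, 2))  # (2, 2)
```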
"""
return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
value_fn)
def gather(self, value, axis):
# pylint: disable=line-too-long, protected-access
"""Gather `value` across replicas along `axis` to the current device.
Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
object `value`, this API gathers and concatenates `value` across replicas
along the `axis`-th dimension. The result is copied to the "current" device,
which would typically be the CPU of the worker on which the program is
running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of
each worker.
This API can only be called in the cross-replica context. For a counterpart
in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
Note: For all strategies except `tf.distribute.TPUStrategy`, the input
`value` on different replicas must have the same rank, and their shapes must
be the same in all dimensions except the `axis`-th dimension. In other
words, their shapes cannot be different in a dimension `d` where `d` does
not equal the `axis` argument. For example, given a
`tf.distribute.DistributedValues` with component tensors of shape
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
`gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
`gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
all tensors must have exactly the same rank and same shape.
Note: Given a `tf.distribute.DistributedValues` `value`, its component
tensors must have a non-zero rank. Otherwise, consider using
`tf.expand_dims` before gathering them.
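For non-TPU strategies, the shape rule above can be checked ahead of time with a small helper. `gathered_shape` is a hypothetical plain-Python sketch, not a TensorFlow API: it takes the static shape of each replica's component and returns the shape `gather` would produce, or raises if the components are incompatible:
```python
def gathered_shape(shapes, axis):
  # Hypothetical helper mirroring the rules above (non-TPU strategies).
  rank = len(shapes[0])
  if rank == 0:
    raise ValueError("components must have non-zero rank")
  if any(len(shape) != rank for shape in shapes):
    raise ValueError("components must all have the same rank")
  if not 0 <= axis < rank:
    raise ValueError("axis must be in [0, rank)")
  for dim in range(rank):
    if dim != axis and len({shape[dim] for shape in shapes}) != 1:
      raise ValueError("shapes may only differ along axis")
  # Sizes along `axis` add up; every other dimension is unchanged.
  total = sum(shape[axis] for shape in shapes)
  return tuple(total if dim == axis else shapes[0][dim]
               for dim in range(rank))
print(gathered_shape([(1, 2, 3), (1, 3, 3)], axis=1))  # (1, 5, 3)
```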
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> # A DistributedValues with component tensor of shape (2, 1) on each replica
... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
>>> @tf.function
... def run():
... return strategy.gather(distributed_values, axis=0)
>>> run()
<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
array([[1],
[2],
[1],
[2]], dtype=int32)>
Consider the following example for more combinations:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
>>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
>>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
>>> @tf.function
... def run(axis):
... return strategy.gather(distributed_values, axis=axis)
>>> axis=0
>>> run(axis)
<tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
array([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]], dtype=int32)>
>>> axis=1
>>> run(axis)
<tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
array([[[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]]], dtype=int32)>
>>> axis=2
>>> run(axis)
<tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
Args:
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
`Strategy.run`, to be combined into a single tensor. It can also be a
regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
default strategy. The tensors that constitute the DistributedValues
can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
Returns:
A `Tensor` that's the concatenation of `value` across replicas along
`axis` dimension.
"""
# pylint: enable=line-too-long
error_message = ("tf.distribute.Strategy.gather method requires "
"cross-replica context, use "
"get_replica_context().all_gather() instead")
_require_cross_replica_or_default_context_extended(self._extended,
error_message)
dst = device_util.current(
) or self._extended._default_device or "/device:CPU:0"
if isinstance(value, ops.IndexedSlices):
raise NotImplementedError("gather does not support IndexedSlices")
return self._extended._local_results(
self._extended._gather_to(value, dst, axis))[0]
# TF v1.x version has additional deprecated APIs
@tf_export(v1=["distribute.Strategy"])
class StrategyV1(StrategyBase):
"""A list of devices with a state & compute distribution policy.
See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
for overview and examples.
Note: Not all `tf.distribute.Strategy` implementations currently support
TensorFlow's partitioned variables (where a single variable is split across
multiple devices).
"""
def make_dataset_iterator(self, dataset):
"""Makes an iterator for input provided via `dataset`.
DEPRECATED: This method is not available in TF 2.x.
Data from the given dataset will be distributed evenly across all the
compute replicas. We will assume that the input dataset is batched by the
global batch size. With this assumption, we will make a best effort to
divide each batch across all the replicas (one or more workers).
If this effort fails, an error will be thrown, and the user should instead
use `make_input_fn_iterator` which provides more control to the user, and
does not try to divide a batch across replicas.
The user could also use `make_input_fn_iterator` if they want to
customize which input is fed to which replica/worker etc.
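A simplified model of the batch splitting described above, assuming the global batch must divide evenly and is handed to replicas in contiguous chunks (the exact assignment is an implementation detail of each strategy). `split_global_batch` is a hypothetical plain-Python helper, not a TensorFlow API:
```python
def split_global_batch(batch, num_replicas):
  # Hypothetical model: fail if the batch cannot be divided evenly, otherwise
  # give each replica an equally sized, contiguous slice.
  if len(batch) % num_replicas != 0:
    raise ValueError(
      "batch of size %d cannot be split across %d replicas"
      % (len(batch), num_replicas))
  per_replica = len(batch) // num_replicas
  return [batch[i * per_replica:(i + 1) * per_replica]
          for i in range(num_replicas)]
print(split_global_batch(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```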
Args:
dataset: `tf.data.Dataset` that will be distributed evenly across all
replicas.
Returns:
A `tf.distribute.InputIterator` which returns inputs for each step of the
computation. The user should call `initialize` on the returned iterator.
"""
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation
input_fn,
replication_mode=InputReplicationMode.PER_WORKER):
"""Returns an iterator split across replicas created from an input function.
DEPRECATED: This method is not available in TF 2.x.
The `input_fn` should take a `tf.distribute.InputContext` object where
information about batching and input sharding can be accessed:
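```
global_batch_size = 64  # Illustrative value; pick one that suits your model.
def input_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)
def replica_fn(inputs):
  # Placeholder per-replica computation; replace with your model's step.
  return tf.reduce_sum(inputs)
with strategy.scope():
  iterator = strategy.make_input_fn_iterator(input_fn)
  replica_results = strategy.experimental_run(replica_fn, iterator)
```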
The `tf.data.Dataset` returned by `input_fn` should have a per-replica
batch size, which may be computed using
`input_context.get_per_replica_batch_size`.
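The arithmetic behind `input_fn` can be sketched in plain Python. Both helpers are hypothetical: `per_replica_batch_size` models `input_context.get_per_replica_batch_size`, which rejects a global batch size that is not divisible by the number of replicas, and `shard` models the element selection of `tf.data.Dataset.shard`, which keeps every `num_shards`-th element starting at `index`:
```python
def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
  # Hypothetical model of InputContext.get_per_replica_batch_size.
  if global_batch_size % num_replicas_in_sync != 0:
    raise ValueError("global batch size must be divisible by the number "
                     "of replicas")
  return global_batch_size // num_replicas_in_sync
def shard(elements, num_shards, index):
  # Hypothetical model of tf.data.Dataset.shard on a Python list: keep the
  # elements whose position modulo `num_shards` equals `index`.
  return elements[index::num_shards]
# e.g. 2 input pipelines (workers) with 8 replicas in total, global batch 64:
print(per_replica_batch_size(64, 8))  # 8
print(shard(list(range(10)), num_shards=2, index=1))  # [1, 3, 5, 7, 9]
```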
Args:
input_fn: A function taking a `tf.distribute.InputContext` object and
returning a `tf.data.Dataset`.
replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
Only `PER_WORKER` is supported currently, which means there will be
a single call to `input_fn` per worker. Replicas will dequeue from the
local `tf.data.Dataset` on their worker.
Returns:
An iterator object that should first be `.initialize()`-ed. It may then
either be passed to `strategy.experimental_run()` or you can
`iterator.get_next()` to get the next value to pass to
`strategy.extended.call_for_each_replica()`.
"""
return super(StrategyV1, self).make_input_fn_iterator(
input_fn, replication_mode)
def experimental_make_numpy_dataset(self, numpy_input, session=None):
"""Makes a tf.data.Dataset for input provided via a numpy array.
This avoids adding `numpy_input` as a large constant in the graph,
and copies the data to the machine or machines that will be processing
the input.
Note that you will likely need to use
`tf.distribute.Strategy.experimental_distribute_dataset`
with the returned dataset to further distribute it with the strategy.
Example:
```
import numpy as np
numpy_input = np.ones([10], dtype=np.float32)
dataset = strategy.experimental_make_numpy_dataset(numpy_input)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```
Args:
numpy_input: A nest of NumPy input arrays that will be converted into a
dataset. Note that lists of Numpy arrays are stacked, as that is normal
`tf.data.Dataset` behavior.
session: (TensorFlow v1.x graph execution only) A session used for
initialization.
Returns:
A `tf.data.Dataset` representing `numpy_input`.
"""
return self.extended.experimental_make_numpy_dataset(
numpy_input, session=session)
def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation
"""Runs ops in `fn` on each replica, with inputs from `input_iterator`.
DEPRECATED: This method is not available in TF 2.x. Please switch
to using `run` instead.
When eager execution is enabled, executes ops specified by `fn` on each
replica. Otherwise, builds a graph to execute the ops on each replica.
Each replica will take a single, different input from the inputs provided by
one `get_next` call on the input iterator.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `replica_id_in_sync_group`.
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
used, and whether eager execution is enabled, `fn` may be called one or more
times (once for each replica).
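Ignoring devices and graph construction, the calling pattern can be modeled in plain Python: one `get_next()` call yields one input per replica, and `fn` is applied to each of them. `FakeIterator` and `run_on_replicas` are hypothetical names for illustration only:
```python
class FakeIterator:
  # Hypothetical iterator whose get_next() returns one input per replica.
  def __init__(self, steps):
    self._steps = iter(steps)
  def get_next(self):
    return next(self._steps)
def run_on_replicas(fn, input_iterator):
  # One get_next() call, then one call of `fn` per replica input.
  per_replica_inputs = input_iterator.get_next()
  return tuple(fn(x) for x in per_replica_inputs)
it = FakeIterator([(1.0, 2.0), (3.0, 4.0)])  # two steps, two replicas
print(run_on_replicas(lambda x: x * 10, it))  # (10.0, 20.0)
print(run_on_replicas(lambda x: x * 10, it))  # (30.0, 40.0)
```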
Args:
fn: The function to run. The inputs to the function must match the outputs
of `input_iterator.get_next()`. The output must be a `tf.nest` of
`Tensor`s.
input_iterator: (Optional) input iterator from which the inputs are taken.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
structure can either be `PerReplica` (if the values are unsynchronized),
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
single replica).
"""
return super(StrategyV1, self).experimental_run(
fn, input_iterator)
def reduce(self, reduce_op, value, axis=None):
return super(StrategyV1, self).reduce(reduce_op, value, axis)
reduce.__doc__ = StrategyBase.reduce.__doc__
def update_config_proto(self, config_proto):
"""Returns a copy of `config_proto` modified for use with this strategy.
DEPRECATED: This method is not available in TF 2.x.
The updated config contains the settings needed to run a strategy, e.g.
configuration to run collective ops, or device filters to improve
distributed training performance.
Args:
config_proto: a `tf.ConfigProto` object.
Returns:
The updated copy of the `config_proto`.
"""
return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access
# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
# instead descend from StrategyExtendedV1.
@tf_export("distribute.StrategyExtended", v1=[])
class StrategyExtendedV2(object):
"""Additional APIs for algorithms that need to be distribution-aware.
Note: For most usage of `tf.distribute.Strategy`, there should be no need to
call these methods, since TensorFlow libraries (such as optimizers) already
call these methods when needed on your behalf.
Some common use cases of functions on this page:
* _Locality_
`tf.distribute.DistributedValues` can have the same _locality_ as a
_distributed variable_, which leads to a mirrored value residing on the same
devices as the variable (as opposed to the compute devices). Such values may
be passed to a call to `tf.distribute.StrategyExtended.update` to update the
value of a variable. You may use
`tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the
same locality as another variable. You may convert a "PerReplica" value to a
variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or
`tf.distribute.StrategyExtended.batch_reduce_to`.
* _How to update a distributed variable_
A distributed variable is a variable created on multiple devices. As discussed
in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute),
mirrored variables and SyncOnRead variables are two examples. The standard
pattern for updating distributed variables is to:
1. In your function passed to `tf.distribute.Strategy.run`,
compute a list of (update, variable) pairs. For example, the update might
be a gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling
`tf.distribute.get_replica_context().merge_call()` with the updates and
variables as arguments.
3. Call
`tf.distribute.StrategyExtended.reduce_to(tf.distribute.ReduceOp.SUM, t, v)`
(for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
(for a list of variables) to sum the updates.
4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
its value.
Steps 2 through 4 are done automatically by class
`tf.keras.optimizers.Optimizer` if you call its
`tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context.
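Steps 1 through 4 can be modeled with plain Python lists standing in for a mirrored variable (one copy per device) and for per-replica gradients. `sgd_step_on_mirrored` is a hypothetical helper that shows only the data flow, assuming plain SGD; it is not how TensorFlow implements the update:
```python
def sgd_step_on_mirrored(mirrored_copies, per_replica_grads, learning_rate):
  # Step 3: reduce the per-replica updates with SUM.
  summed_grad = sum(per_replica_grads)
  # Step 4: apply the same update on every device holding a copy, so the
  # copies stay identical.
  return [w - learning_rate * summed_grad for w in mirrored_copies]
# Two replicas, each holding a copy of a scalar weight equal to 1.0.
print(sgd_step_on_mirrored([1.0, 1.0], [0.25, 0.75], learning_rate=0.5))  # [0.5, 0.5]
```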
A higher-level way to update a distributed variable is to call `assign` on
the variable, as you would on a regular `tf.Variable`. You can call the
method in both _replica context_ and _cross-replica context_.
For a _mirrored variable_, calling `assign` in _replica context_ requires you
to specify the `aggregation` type in the variable constructor. In that case,
the context switching and sync described in steps 2 through 4 are handled for
you. If you call `assign` on a _mirrored variable_ in _cross-replica context_,
you can only assign a single value, or assign values from another mirrored
variable or a mirrored `tf.distribute.DistributedValues`.
For a _SyncOnRead variable_, in _replica context_, you can simply call
`assign` on it and no aggregation happens under the hood. In _cross-replica
context_, you can only assign a single value to a SyncOnRead variable. One
example is restoring from a checkpoint: if the `aggregation` type of the
variable is `tf.VariableAggregation.SUM`, it is assumed that replica values
were added before checkpointing, so at restore time the value is divided by
the number of replicas and then assigned to each replica; if the
`aggregation` type is `tf.VariableAggregation.MEAN`, the value is assigned to
each replica directly.
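The checkpoint-restore rule for SyncOnRead variables can be sketched in plain Python. `restore_sync_on_read` is a hypothetical helper, not a TensorFlow API. It returns the value each replica receives; for `SUM`, the per-replica values add back up to the saved total:
```python
def restore_sync_on_read(saved_value, num_replicas, aggregation):
  # Hypothetical model of restoring a SyncOnRead variable's checkpoint value.
  if aggregation == "SUM":
    # The saved value was a sum over replicas, so split it evenly.
    return [saved_value / num_replicas] * num_replicas
  if aggregation == "MEAN":
    # The saved value was a mean, so every replica gets it unchanged.
    return [saved_value] * num_replicas
  raise ValueError("unsupported aggregation: %s" % aggregation)
print(restore_sync_on_read(8.0, 4, "SUM"))   # [2.0, 2.0, 2.0, 2.0]
print(restore_sync_on_read(8.0, 4, "MEAN"))  # [8.0, 8.0, 8.0, 8.0]
```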
"""
def __init__(self, container_strategy):
self._container_strategy_weakref = weakref.ref(container_strategy)
self._default_device = None
# This property is used to determine if we should set drop_remainder=True
# when creating Datasets from numpy array inputs.
self._require_static_shapes = False
def _container_strategy(self):
"""Get the containing `tf.distribute.Strategy`.
This should not generally be needed except when creating a new
`ReplicaContext` and to validate that the caller is in the correct
`scope()`.
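The strategy and its extended object are linked through a weak back-reference, and the strategy's `__deepcopy__` re-points that reference after copying, because `copy.deepcopy` treats `weakref.ref` objects as atomic and would otherwise leave the copy pointing at the original. The sketch below models that pattern with hypothetical `Container` and `Extended` classes; it is not the TensorFlow implementation:
```python
import copy
import weakref
class Extended:
  # Hypothetical stand-in for StrategyExtended: holds only a weak
  # back-reference to its container, so the pair forms no strong cycle.
  def __init__(self, container):
    self._container_weakref = weakref.ref(container)
  def container(self):
    container = self._container_weakref()
    assert container is not None
    return container
class Container:
  # Hypothetical stand-in for tf.distribute.Strategy.
  def __init__(self):
    self.extended = Extended(self)
  def __deepcopy__(self, memo):
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      setattr(result, k, copy.deepcopy(v, memo))
    # The copied weakref still targets `self`; point it at the copy instead.
    result.extended._container_weakref = weakref.ref(result)
    return result
original = Container()
clone = copy.deepcopy(original)
print(clone.extended.container() is clone)  # True
```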
Returns:
The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
"""
container_strategy = self._container_strategy_weakref()
assert container_strategy is not None
return container_strategy
def _scope(self, strategy):
"""Implementation of tf.distribute.Strategy.scope()."""
def creator_with_resource_vars(next_creator, **kwargs):
"""Variable creator to use in `_CurrentDistributionContext`."""
_require_strategy_scope_extended(self)
kwargs["use_resource"] = True
kwargs["distribute_strategy"] = strategy
# Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
# dereferencing a `Tensor` that is without a `name`. We still need to
# propagate the metadata it's holding.
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
checkpoint_restore_uid = kwargs[
"initial_value"].checkpoint_position.restore_uid
kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
elif isinstance(kwargs["initial_value"],
trackable.CheckpointInitialValueCallable):
checkpoint_restore_uid = kwargs[
"initial_value"].checkpoint_position.restore_uid
elif (isinstance(kwargs["initial_value"], functools.partial) and
isinstance(kwargs["initial_value"].func,
trackable.CheckpointInitialValueCallable)):
# Some libraries (e.g., Keras) create a partial function out of an
# initializer to bind shape/dtype, for example:
# initial_val = functools.partial(initializer, shape, dtype=dtype)
# Therefore to get the restore_uid we need to examine the "func" of
# the partial function.
checkpoint_restore_uid = kwargs[
"initial_value"].func.checkpoint_position.restore_uid
else:
checkpoint_restore_uid = None
created = self._create_variable(next_creator, **kwargs)
if checkpoint_restore_uid is not None:
# pylint: disable=protected-access
# Let the checkpointing infrastructure know that the variable was
# already restored so it doesn't waste memory loading the value again.
# In the case of CheckpointInitialValueCallable, this may already be
# done by the final variable creator, but it doesn't hurt to do it
# again.
created._maybe_initialize_trackable()
created._update_uid = checkpoint_restore_uid
# pylint: enable=protected-access
return created
def distributed_getter(getter, *args, **kwargs):
if not self._allow_variable_partition():
if kwargs.pop("partitioner", None) is not None:
tf_logging.log_first_n(
tf_logging.WARN, "Partitioned variables are disabled when using "
"current tf.distribute.Strategy.", 1)
return getter(*args, **kwargs)
return _CurrentDistributionContext(
strategy,
variable_scope.variable_creator_scope(creator_with_resource_vars),
variable_scope.variable_scope(
variable_scope.get_variable_scope(),
custom_getter=distributed_getter), self._default_device)
def _allow_variable_partition(self):
return False
def _create_variable(self, next_creator, **kwargs):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
def variable_created_in_scope(self, v):
"""Tests whether `v` was created while this strategy scope was active.
Variables created inside the strategy scope are "owned" by it:
>>> strategy = tf.distribute.MirroredStrategy()
>>> with strategy.scope():
... v = tf.Variable(1.)
>>> strategy.extended.variable_created_in_scope(v)
True
Variables created outside the strategy scope are not owned by it:
>>> strategy = tf.distribute.MirroredStrategy()
>>> v = tf.Variable(1.)
>>> strategy.extended.variable_created_in_scope(v)
False
Args:
v: A `tf.Variable` instance.
Returns:
True if `v` was created inside the scope, False if not.
"""
return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access
def colocate_vars_with(self, colocate_with_variable):
"""Scope that controls which devices variables will be created on.
No operations should be added to the graph inside this scope; it should only
be used when creating variables (some implementations work by changing
variable creation, others work by using a `tf.compat.v1.colocate_with()`
scope).
This may only be used inside `self.scope()`.
Example usage:
```
with strategy.scope():
  var1 = tf.Variable(1.0)
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.Variable(2.0)
    var3 = tf.Variable(3.0)
  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3
    return v1.assign(v2 + v3)
  # `fn` runs on every device `var1` is on; `var2` and `var3` will be there
  # too.
  strategy.extended.update(var1, fn, args=(var2, var3))
```
Args:
colocate_with_variable: A variable created in this strategy's `scope()`.
Variables created while in the returned context manager will be on the
same set of devices as `colocate_with_variable`.
Returns:
A context manager.
"""
def create_colocated_variable(next_creator, **kwargs):
_require_strategy_scope_extended(self)
kwargs["use_resource"] = True
kwargs["colocate_with"] = colocate_with_variable
return next_creator(**kwargs)
_require_strategy_scope_extended(self)
self._validate_colocate_with_variable(colocate_with_variable)
return variable_scope.variable_creator_scope(create_colocated_variable)
def _validate_colocate_with_variable(self, colocate_with_variable):
"""Validate `colocate_with_variable` argument to `colocate_vars_with`."""
pass
def _make_dataset_iterator(self, dataset):
raise NotImplementedError("must be implemented in descendants")
def _make_input_fn_iterator(self, input_fn, replication_mode):
raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_dataset(self, dataset, options):
raise NotImplementedError("must be implemented in descendants")
def _distribute_datasets_from_function(self, dataset_fn, options):
raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_values_from_function(self, value_fn):
raise NotImplementedError("must be implemented in descendants")
def _reduce(self, reduce_op, value):
# Default implementation until we have an implementation for each strategy.
dst = device_util.current() or self._default_device or "/device:CPU:0"
return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
def reduce_to(self, reduce_op, value, destinations, options=None):
"""Combine (via e.g. sum or mean) values across replicas.
`reduce_to` aggregates `tf.distribute.DistributedValues` and distributed
variables. It supports both dense values and `tf.IndexedSlices`.
This API currently can only be called in cross-replica context. Other
variants to reduce values across replicas are:
* `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of
this API.
* `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
in replica context. It supports both batched and non-batched all-reduce.
* `tf.distribute.Strategy.reduce`: a more convenient method to reduce
to the host in cross-replica context.
`destinations` specifies where to reduce the value to, e.g. "GPU:0". You can
also pass in a `Tensor`, and the destinations will be the device of that
tensor. For all-reduce, pass the same to `value` and `destinations`.
It can be used in `tf.distribute.ReplicaContext.merge_call` to write code
that works for all `tf.distribute.Strategy`.
>>> @tf.function
... def step_fn(var):
...
... def merge_fn(strategy, value, var):
... # All-reduce the value. Note that `value` here is a
... # `tf.distribute.DistributedValues`.
... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
... value, destinations=var)
... strategy.extended.update(var, lambda var, value: var.assign(value),
... args=(reduced,))
...
... value = tf.identity(1.)
... tf.distribute.get_replica_context().merge_call(merge_fn,
... args=(value, var))
>>>
>>> def run(strategy):
... with strategy.scope():
... v = tf.Variable(0.)
... strategy.run(step_fn, args=(v,))
... return v
>>>
>>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
}
>>> run(tf.distribute.experimental.CentralStorageStrategy(
... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
>>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
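
The values above follow from summing the `1.` contributed by each replica: two replicas give `2.0`, and a single device gives `1.0`. A plain-Python sketch of the two supported ops, including the case-insensitive string form accepted for `reduce_op`, follows; `reduce_values` is a hypothetical helper, not a TensorFlow API:
```python
def reduce_values(reduce_op, per_replica_values):
  # Hypothetical model of reduce_to's arithmetic; strings such as "sum" or
  # "MEAN" are upper-cased first, as reduce_to does.
  op = reduce_op.upper()
  if op == "SUM":
    return sum(per_replica_values)
  if op == "MEAN":
    return sum(per_replica_values) / len(per_replica_values)
  raise ValueError("only SUM and MEAN are supported")
print(reduce_values("SUM", [1.0, 1.0]))   # 2.0 (two replicas)
print(reduce_values("sum", [1.0]))        # 1.0 (one device)
print(reduce_values("MEAN", [1.0, 3.0]))  # 2.0
```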
Args:
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representation of the enum such as
"SUM", "MEAN".
value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`-like object.
destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
`tf.Tensor`-like object, or a device string. It specifies the devices
to reduce to. To perform an all-reduce, pass the same to `value` and
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
to the devices of that variable, and this method doesn't update the
variable.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.
Returns:
A tensor or value reduced to `destinations`.
"""
if options is None:
options = collective_util.Options()
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(destinations, (list, tuple))
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
assert (reduce_op == reduce_util.ReduceOp.SUM or
reduce_op == reduce_util.ReduceOp.MEAN)
return self._reduce_to(reduce_op, value, destinations, options)
def _reduce_to(self, reduce_op, value, destinations, options):
raise NotImplementedError("must be implemented in descendants")
def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
"""Combine multiple `reduce_to` calls into one for faster execution.
Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.
It's more efficient than reducing each value separately.
This API currently can only be called in cross-replica context. Other
variants to reduce values across replicas are:
* `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of
this API.
* `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
in replica context. It supports both batched and non-batched all-reduce.
* `tf.distribute.Strategy.reduce`: a more convenient method to reduce
to the host in cross-replica context.
See `reduce_to` for more information.
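When a strategy has no fused implementation, the default `_batch_reduce_to` simply calls `reduce_to` once per pair and returns the results in the same order. The plain-Python sketch below models that fallback; `reduce_one` and `batch_reduce` are hypothetical helpers, and `destinations` is carried along only to show the pairing:
```python
def reduce_one(reduce_op, per_replica_values, destinations):
  # Hypothetical stand-in for a single reduce_to call; it returns the reduced
  # value together with the destinations it would be placed on.
  if reduce_op not in ("SUM", "MEAN"):
    raise ValueError("only SUM and MEAN are supported")
  reduced = sum(per_replica_values)
  if reduce_op == "MEAN":
    reduced /= len(per_replica_values)
  return reduced, destinations
def batch_reduce(reduce_op, value_destination_pairs):
  # Same shape as the default fallback: one reduce per pair, order preserved.
  return [reduce_one(reduce_op, values, destinations)
          for values, destinations in value_destination_pairs]
pairs = [([1.0, 1.0], "GPU:0"), ([2.0, 4.0], "CPU:0")]
print(batch_reduce("SUM", pairs))   # [(2.0, 'GPU:0'), (6.0, 'CPU:0')]
print(batch_reduce("MEAN", pairs))  # [(1.0, 'GPU:0'), (3.0, 'CPU:0')]
```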
>>> @tf.function
... def step_fn(var):
...
... def merge_fn(strategy, value, var):
... # All-reduce the value. Note that `value` here is a
... # `tf.distribute.DistributedValues`.
... reduced = strategy.extended.batch_reduce_to(
... tf.distribute.ReduceOp.SUM, [(value, var)])[0]
... strategy.extended.update(var, lambda var, value: var.assign(value),
... args=(reduced,))
...
... value = tf.identity(1.)
... tf.distribute.get_replica_context().merge_call(merge_fn,
... args=(value, var))
>>>
>>> def run(strategy):
... with strategy.scope():
... v = tf.Variable(0.)
... strategy.run(step_fn, args=(v,))
... return v
>>>
>>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
}
>>> run(tf.distribute.experimental.CentralStorageStrategy(
... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
>>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
Args:
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representation of the enum such as
"SUM", "MEAN".
value_destination_pairs: a sequence of (value, destinations) pairs. See
`tf.distribute.StrategyExtended.reduce_to` for descriptions.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.
Returns:
A list of reduced values, one per pair in `value_destination_pairs`.
"""
if options is None:
options = collective_util.Options()
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
return [
self.reduce_to(reduce_op, t, destinations=v, options=options)
for t, v in value_destination_pairs
]
def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
"""All-reduce `value` across all replicas so that all get the final result.
If `value` is a nested structure of tensors, all-reduces of these tensors
will be batched when possible. `options` can be set to hint the batching
behavior.
This API must be called in a replica context.
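The method flattens `value` with `nest.flatten`, reduces all leaves in a single `batch_reduce_to` call inside `merge_call`, and restores the structure with `nest.pack_sequence_as`. Below is a simplified, framework-free sketch of that flatten/reduce/pack pattern for a SUM reduction, assuming nested lists and tuples of floats. The real `nest` utilities also handle dicts, namedtuples and more, and all `toy_*` names are hypothetical.
```python
def toy_flatten(structure):
  # Depth-first flattening of nested lists/tuples only.
  if isinstance(structure, (list, tuple)):
    return [leaf for item in structure for leaf in toy_flatten(item)]
  return [structure]
def toy_pack_sequence_as(structure, flat):
  # Rebuild `structure`, consuming leaves from `flat` in depth-first order.
  leaves = iter(flat)
  def build(s):
    if isinstance(s, (list, tuple)):
      return type(s)(build(item) for item in s)
    return next(leaves)
  return build(structure)
def toy_all_reduce_sum(per_replica_structures):
  # Flatten each replica's structure, reduce leaf-wise in one batch, then
  # restore the original nesting.
  flats = [toy_flatten(s) for s in per_replica_structures]
  if len({len(f) for f in flats}) != 1:
    raise ValueError("All replicas must pass the same structure.")
  reduced = [sum(leaves) for leaves in zip(*flats)]
  return toy_pack_sequence_as(per_replica_structures[0], reduced)
reduced = toy_all_reduce_sum([(1.0, [2.0, 3.0]), (1.0, [2.0, 3.0])])
# reduced == (2.0, [4.0, 6.0])
```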
Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined.
value: Value to be reduced. A tensor or a nested structure of tensors.
options: A `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor.
Returns:
A tensor or a nested structure of tensors with the reduced values. The
structure is the same as `value`.
"""
if options is None:
options = collective_util.Options()
replica_context = distribution_strategy_context.get_replica_context()
assert replica_context, (
"`StrategyExtended._replica_ctx_all_reduce` must be called in"
" a replica context")
def merge_fn(_, flat_value):
return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
options)
reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
return nest.pack_sequence_as(value, reduced)
def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True):
"""Run `fn` with `args` and `kwargs` to update `var`."""
# This method is called by ReplicaContext.update. Strategies that want to
# remove merge_call in this path should override this method.
replica_context = distribution_strategy_context.get_replica_context()
if not replica_context:
raise ValueError("`StrategyExtended._replica_ctx_update` must be called "
"in a replica context.")
def merge_fn(_, *merged_args, **merged_kwargs):
return self.update(var, fn, merged_args, merged_kwargs, group=group)
return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
def _gather_to(self, value, destinations, axis, options=None):
"""Gather `value` across replicas along the `axis`-th dimension to `destinations`.
`gather_to` gathers a `tf.distribute.DistributedValues` or `tf.Tensor`-like
object along the `axis`-th dimension. It supports only dense tensors, NOT
sparse tensors. This API can only be called in cross-replica context.
Args:
value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
`tf.Tensor`-like object, or a device string. It specifies the devices
to gather to. To perform an all-gather, pass the same object to `value`
and `destinations`. Note that if it's a `tf.Variable`, the value is
gathered to the devices of that variable, and this method doesn't update
the variable.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.
Returns:
A tensor or value gathered to `destinations`.
"""
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(destinations, (list, tuple))
if options is None:
options = collective_util.Options()
return self._gather_to_implementation(value, destinations, axis, options)
def _gather_to_implementation(self, value, destinations, axis, options):
raise NotImplementedError("_gather_to must be implemented in descendants")
def _batch_gather_to(self, value_destination_pairs, axis, options=None):
_require_cross_replica_or_default_context_extended(self)
if options is None:
options = collective_util.Options()
return [
self._gather_to(t, destinations=v, axis=axis, options=options)
for t, v in value_destination_pairs
]
def update(self, var, fn, args=(), kwargs=None, group=True):
"""Run `fn` to update `var` using inputs mirrored to the same devices.
`tf.distribute.StrategyExtended.update` takes a distributed variable `var`
to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It
applies `fn` to each component variable of `var` and passes corresponding
values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain
per-replica values. If they contain mirrored values, they will be unwrapped
before calling `fn`. For example, `fn` can be `assign_add` and `args` can be
a mirrored DistributedValues where each component contains the value to be
added to this mirrored variable `var`. Calling `update` will call
`assign_add` on each component variable of `var` with the corresponding
tensor value on that device.
Example usage:
```python
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices
with strategy.scope():
v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
def update_fn(v):
return v.assign(1.0)
result = strategy.extended.update(v, update_fn)
# result is
# Mirrored:{
# 0: tf.Tensor(1.0, shape=(), dtype=float32),
# 1: tf.Tensor(1.0, shape=(), dtype=float32)
# }
```
If `var` is mirrored across multiple devices, then this method implements
logic as follows:
```python
results = {}
for device, v in var:
with tf.device(device):
# args and kwargs will be unwrapped if they are mirrored.
results[device] = fn(v, *args, **kwargs)
return merged(results)
```
Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with
`var`.
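A minimal, framework-free sketch of the mirrored case, assuming a mirrored variable is modeled as a list of hypothetical `ToyComponent` objects that each carry a `device` attribute. It does not model the unwrapping of mirrored `args`; every component receives the same plain arguments.
```python
class ToyComponent:
  # Hypothetical stand-in for one device-local component of a mirrored
  # variable; no TensorFlow is involved.
  def __init__(self, device, value):
    self.device = device
    self.value = value
  def assign(self, new_value):
    self.value = new_value
    return self.value
def toy_update(components, fn, args=(), kwargs=None):
  kwargs = kwargs or {}
  results = {}
  for v in components:
    # In the pseudocode above, this call runs under `tf.device(device)`.
    results[v.device] = fn(v, *args, **kwargs)
  return results
var = [ToyComponent("GPU:0", 5.0), ToyComponent("GPU:1", 5.0)]
result = toy_update(var, lambda v, x: v.assign(x), args=(1.0,))
# result == {"GPU:0": 1.0, "GPU:1": 1.0}
```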
Args:
var: Variable, possibly mirrored to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
args: Tuple or list. Additional positional arguments to pass to `fn()`.
kwargs: Dict with keyword arguments to pass to `fn()`.
group: Boolean. Defaults to True. If False, the return value will be
unwrapped.
Returns:
By default, the merged return value of `fn` across all replicas. The
merged result has dependencies to make sure that if it is evaluated at
all, the side effects (updates) will happen on every replica. If instead
"group=False" is specified, this function will return a nest of lists
where each list has an element per replica, and the caller is responsible
for ensuring all elements are executed.
"""
# TODO(b/178944108): Update the documentation to reflect the fact that
# `update` can be called in a replica context.
if kwargs is None:
kwargs = {}
replica_context = distribution_strategy_context.get_replica_context()
# pylint: disable=protected-access
if (replica_context is None or replica_context is
distribution_strategy_context._get_default_replica_context()):
fn = autograph.tf_convert(
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
with self._container_strategy().scope():
return self._update(var, fn, args, kwargs, group)
else:
return self._replica_ctx_update(
var, fn, args=args, kwargs=kwargs, group=group)
def _update(self, var, fn, args, kwargs, group):
raise NotImplementedError("must be implemented in descendants")
def _local_results(self, val):
"""Returns local results per replica as a tuple."""
if isinstance(val, values.DistributedValues):
return val._values # pylint: disable=protected-access
if nest.is_nested(val):
replica_values = []
def get_values(x, index):
if isinstance(x, values.DistributedValues):
return x._values[index] # pylint: disable=protected-access
return x
for i in range(len(self.worker_devices)):
replica_values.append(
nest.map_structure(
lambda x: get_values(x, i), # pylint: disable=cell-var-from-loop
val))
return tuple(replica_values)
return (val,)
def value_container(self, value):
"""Returns the container that this per-replica `value` belongs to.
Args:
value: A value returned by `run()` or a variable created in `scope()`.
Returns:
A container that `value` belongs to.
If value does not belong to any container (including the case of
container having been destroyed), returns the value itself.
`value in experimental_local_results(value_container(value))` will
always be true.
"""
raise NotImplementedError("must be implemented in descendants")
def _group(self, value, name=None):
"""Implementation of `group`."""
value = nest.flatten(self._local_results(value))
if len(value) != 1 or name is not None:
return control_flow_ops.group(value, name=name)
# Special handling for the common case of one op.
v, = value
if hasattr(v, "op"):
v = v.op
return v
@property
def experimental_require_static_shapes(self):
"""Returns `True` if static shape is required; `False` otherwise."""
return self._require_static_shapes
@property
def _num_replicas_in_sync(self):
"""Returns number of replicas over which gradients are aggregated."""
raise NotImplementedError("must be implemented in descendants")
@property
def worker_devices(self):
"""Returns the tuple of all devices used for compute replica execution.
"""
# TODO(josh11b): More docstring
raise NotImplementedError("must be implemented in descendants")
@property
def parameter_devices(self):
"""Returns the tuple of all devices used to place variables."""
# TODO(josh11b): More docstring
raise NotImplementedError("must be implemented in descendants")
def _configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
"""Configures the strategy class."""
del session_config, cluster_spec, task_type, task_id
def _update_config_proto(self, config_proto):
return copy.deepcopy(config_proto)
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings.
Multi-worker training refers to the setup where the training is
distributed across multiple workers, as opposed to the case where
only a local process performs the training. This function is
used by higher-level APIs such as Keras' `model.fit()` to infer
for example whether or not a distribute coordinator should be run,
and thus TensorFlow servers should be started for communication
with other servers in the cluster, or whether or not saving/restoring
checkpoints is relevant for preemption fault tolerance.
Subclasses should override this to provide whether the strategy is
currently in multi-worker setup.
Experimental. Signature and implementation are subject to change.
"""
raise NotImplementedError("must be implemented in descendants")
@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring
class StrategyExtendedV1(StrategyExtendedV2):
__doc__ = StrategyExtendedV2.__doc__
def experimental_make_numpy_dataset(self, numpy_input, session=None):
"""Makes a dataset for input provided via a numpy array.
This avoids adding `numpy_input` as a large constant in the graph,
and copies the data to the machine or machines that will be processing
the input.
Args:
numpy_input: A nest of NumPy input arrays that will be distributed evenly
across all replicas. Note that lists of NumPy arrays are stacked, as
that is normal `tf.data.Dataset` behavior.
session: (TensorFlow v1.x graph execution only) A session used for
initialization.
Returns:
A `tf.data.Dataset` representing `numpy_input`.
"""
_require_cross_replica_or_default_context_extended(self)
return self._experimental_make_numpy_dataset(numpy_input, session=session)
def _experimental_make_numpy_dataset(self, numpy_input, session):
raise NotImplementedError("must be implemented in descendants")
def broadcast_to(self, tensor, destinations):
"""Mirror a tensor on one device to all worker devices.
Args:
tensor: A Tensor value to broadcast.
destinations: A mirrored variable or device string specifying the
destination devices to copy `tensor` to.
Returns:
A value mirrored to `destinations` devices.
"""
assert destinations is not None # from old strategy.broadcast()
# TODO(josh11b): More docstring
_require_cross_replica_or_default_context_extended(self)
assert not isinstance(destinations, (list, tuple))
return self._broadcast_to(tensor, destinations)
def _broadcast_to(self, tensor, destinations):
raise NotImplementedError("must be implemented in descendants")
def experimental_run_steps_on_iterator(self,
fn,
iterator,
iterations=1,
initial_loop_values=None):
"""DEPRECATED: please use `run` instead.
Run `fn` `iterations` times with input from `iterator`.
This method can be used to run a step function for training a number of
times using input from a dataset.
Args:
fn: function to run using this distribution strategy. The function must
have the following signature: `def fn(context, inputs)`. `context` is an
instance of `MultiStepContext` that will be passed when `fn` is run.
`context` can be used to specify the outputs to be returned from `fn`
by calling `context.set_last_step_output`. It can also be used to
capture non-tensor outputs by `context.set_non_tensor_output`. See
`MultiStepContext` documentation for more information. `inputs` will
have the same type/structure as `iterator.get_next()`. Typically, `fn`
will use `call_for_each_replica` method of the strategy to distribute
the computation over multiple replicas.
iterator: Iterator of a dataset that represents the input for `fn`. The
caller is responsible for initializing the iterator as needed.
iterations: (Optional) Number of iterations that `fn` should be run.
Defaults to 1.
initial_loop_values: (Optional) Initial values to be passed into the
loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
initial_loop_values argument when we have a mechanism to infer the
outputs of `fn`.
Returns:
Returns the `MultiStepContext` object which has the following properties,
among other things:
- run_op: An op that runs `fn` `iterations` times.
- last_step_outputs: A dictionary containing tensors set using
`context.set_last_step_output`. Evaluating this returns the value of
the tensors after the last iteration.
- non_tensor_outputs: A dictionary containing anything that was set by
`fn` by calling `context.set_non_tensor_output`.
"""
_require_cross_replica_or_default_context_extended(self)
with self._container_strategy().scope():
return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
initial_loop_values)
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values):
raise NotImplementedError("must be implemented in descendants")
def call_for_each_replica(self, fn, args=(), kwargs=None):
"""Run `fn` once per replica.
`fn` may call `tf.get_replica_context()` to access methods such as
`replica_id_in_sync_group` and `merge_call()`.
`merge_call()` is used to communicate between the replicas and
re-enter the cross-replica context. Each replica pauses its execution
when it encounters a `merge_call()` call. Once all replicas have reached
that call, `merge_fn` is executed once in the cross-replica context. Its
results are then unwrapped and given back to each replica, and execution
resumes until `fn` completes or encounters another `merge_call()`. Example:
```python
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
# sum the values across replicas
return sum(distribution.experimental_local_results(three_plus_replica_id))
# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
replica_ctx = tf.get_replica_context()
v = three + replica_ctx.replica_id_in_sync_group
# Computes the sum of the `v` values across all replicas.
s = replica_ctx.merge_call(merge_fn, args=(v,))
return s + v
with distribution.scope():
# in "cross-replica" context
...
merged_results = distribution.run(fn, args=[3])
# merged_results has the values from every replica execution of `fn`.
# This statement prints a list:
print(distribution.experimental_local_results(merged_results))
```
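To illustrate the pause-and-resume behavior of `merge_call()`, here is a hedged, framework-free model in which each replica is a Python thread and a `threading.Barrier` holds every replica until all have arrived. `ToyReplicaGroup` and its methods are hypothetical and only mirror the example above; they are not how any `tf.distribute.Strategy` is implemented.
```python
import threading
class ToyReplicaGroup:
  # Hypothetical model: replicas are threads; merge_call blocks until all
  # replicas arrive, runs merge_fn once, and returns its result to each.
  def __init__(self, num_replicas):
    self.num_replicas = num_replicas
    self._values = [None] * num_replicas
    self._merge_fn = None
    self._result = None
    # The barrier action runs once, in one thread, after every replica has
    # called wait() and before any replica is released.
    self._barrier = threading.Barrier(num_replicas, action=self._merge)
  def _merge(self):
    self._result = self._merge_fn(list(self._values))
  def merge_call(self, replica_id, merge_fn, value):
    self._merge_fn = merge_fn
    self._values[replica_id] = value
    self._barrier.wait()
    return self._result
  def run(self, fn, *args):
    outputs = [None] * self.num_replicas
    def replica_main(replica_id):
      outputs[replica_id] = fn(self, replica_id, *args)
    threads = [threading.Thread(target=replica_main, args=(i,))
               for i in range(self.num_replicas)]
    for t in threads:
      t.start()
    for t in threads:
      t.join()
    return outputs
def fn(group, replica_id, three):
  v = three + replica_id
  # Every replica receives the same merged sum of all `v` values.
  s = group.merge_call(replica_id, sum, v)
  return s + v
outputs = ToyReplicaGroup(2).run(fn, 3)  # [10, 11]
```
With two replicas, `v` is 3 and 4, the merged sum is 7, and each replica returns `7 + v`.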
Args:
fn: function to run (will be run once per replica).
args: Tuple or list with positional arguments for `fn`.
kwargs: Dict with keyword arguments for `fn`.
Returns:
Merged return value of `fn` across all replicas.
"""
_require_cross_replica_or_default_context_extended(self)
if kwargs is None:
kwargs = {}
with self._container_strategy().scope():
return self._call_for_each_replica(fn, args, kwargs)
def _call_for_each_replica(self, fn, args, kwargs):
raise NotImplementedError("must be implemented in descendants")
def read_var(self, v):
"""Reads the value of a variable.
Returns the aggregate value of a replica-local variable, or the
(read-only) value of any other variable.
Args:
v: A variable allocated within the scope of this `tf.distribute.Strategy`.
Returns:
A tensor representing the value of `v`, aggregated across replicas if
necessary.
"""
raise NotImplementedError("must be implemented in descendants")
def update_non_slot(
self, colocate_with, fn, args=(), kwargs=None, group=True):
"""Runs `fn(*args, **kwargs)` on `colocate_with` devices.
Used to update non-slot variables.
DEPRECATED: TF 1.x ONLY.
Args:
colocate_with: Devices returned by `non_slot_devices()`.
fn: Function to execute.
args: Tuple or list. Positional arguments to pass to `fn()`.
kwargs: Dict with keyword arguments to pass to `fn()`.
group: Boolean. Defaults to True. If False, the return value will be
unwrapped.
Returns:
Return value of `fn`, possibly merged across devices.
"""
_require_cross_replica_or_default_context_extended(self)
if kwargs is None:
kwargs = {}
fn = autograph.tf_convert(
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
with self._container_strategy().scope():
return self._update_non_slot(colocate_with, fn, args, kwargs, group)
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
raise NotImplementedError("must be implemented in descendants")
def non_slot_devices(self, var_list):
"""Device(s) for non-slot variables.
DEPRECATED: TF 1.x ONLY.
This method returns non-slot devices where non-slot variables are placed.
Users can create non-slot variables on these devices by using a block:
```python
devices = strategy.extended.non_slot_devices(...)
with strategy.extended.colocate_vars_with(devices):
...
```
Args:
var_list: The list of variables being optimized, needed with the
default `tf.distribute.Strategy`.
Returns:
A sequence of devices for non-slot variables.
"""
raise NotImplementedError("must be implemented in descendants")
def _use_merge_call(self):
"""Whether to use merge-calls inside the distributed strategy."""
return True
@property
def experimental_between_graph(self):
"""Whether the strategy uses between-graph replication or not.
This is expected to return a constant value that will not be changed
throughout its life cycle.
"""
raise NotImplementedError("must be implemented in descendants")
@property
def experimental_should_init(self):
"""Whether initialization is needed."""
raise NotImplementedError("must be implemented in descendants")
@property
def should_checkpoint(self):
"""Whether checkpointing is needed."""
raise NotImplementedError("must be implemented in descendants")
@property
def should_save_summary(self):
"""Whether saving summaries is needed."""
raise NotImplementedError("must be implemented in descendants")
# A note about the difference between the context managers
# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
# (defined above) used by `tf.distribute.Strategy.scope()`:
#
# * a ReplicaContext is only present during a `run()`
# call (except during a `merge_call` call) and in such a scope it
# will be returned by calls to `get_replica_context()`. Implementers of new
# Strategy descendants will frequently also need to
# define a descendant of ReplicaContext, and are responsible for
# entering and exiting this context.
#
# * Strategy.scope() sets up a variable_creator scope that
# changes variable creation calls (e.g. to make mirrored
# variables). This is intended as an outer scope that users enter once
# around their model creation and graph definition. There is no
# anticipated need to define descendants of _CurrentDistributionContext.
# It sets the current Strategy for purposes of
# `get_strategy()` and `has_strategy()`
# and switches the thread mode to a "cross-replica context".
class ReplicaContextBase(object):
"""A class with a collection of APIs that can be called in a replica context.
Call `tf.distribute.get_replica_context` inside the function passed to
`tf.distribute.Strategy.run` to get an instance of `ReplicaContext`.
>>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
>>> def func():
... replica_context = tf.distribute.get_replica_context()
... return replica_context.replica_id_in_sync_group
>>> strategy.run(func)
PerReplica:{
0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
"""
def __init__(self, strategy, replica_id_in_sync_group):
"""Creates a ReplicaContext.
Args:
strategy: A `tf.distribute.Strategy`.
replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an
integer whenever possible to avoid issues with nested `tf.function`. It
accepts a `Tensor` only to be compatible with `tpu.replicate`.
"""
self._strategy = strategy
self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access
self)
if not (replica_id_in_sync_group is None or
tensor_util.is_tf_type(replica_id_in_sync_group) or
isinstance(replica_id_in_sync_group, int)):
raise ValueError(
"replica_id_in_sync_group can only be an integer, a Tensor or None.")
self._replica_id_in_sync_group = replica_id_in_sync_group
# We need this check because TPUContext extends from ReplicaContext and
# does not pass a strategy object since it is used by TPUEstimator.
if strategy:
self._local_replica_id = strategy.extended._get_local_replica_id(
replica_id_in_sync_group)
self._summary_recording_distribution_strategy = None
@doc_controls.do_not_generate_docs
def __enter__(self):
_push_per_thread_mode(self._thread_context)
def replica_id_is_zero():
return math_ops.equal(self.replica_id_in_sync_group,
constant_op.constant(0))
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
self._summary_recording_distribution_strategy = (
summary_state.is_recording_distribution_strategy)
summary_state.is_recording_distribution_strategy = replica_id_is_zero
@doc_controls.do_not_generate_docs
def __exit__(self, exception_type, exception_value, traceback):
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
summary_state.is_recording_distribution_strategy = (
self._summary_recording_distribution_strategy)
_pop_per_thread_mode()
def merge_call(self, merge_fn, args=(), kwargs=None):
"""Merge args across replicas and run `merge_fn` in a cross-replica context.
This allows communication and coordination when there are multiple calls
to the step_fn triggered by a call to `strategy.run(step_fn, ...)`.
See `tf.distribute.Strategy.run` for an explanation.
If not inside a distributed scope, this is equivalent to:
```
strategy = tf.distribute.get_strategy()
with cross-replica-context(strategy):
return merge_fn(strategy, *args, **kwargs)
```
Args:
merge_fn: Function that joins arguments from threads that are given as
PerReplica. It accepts a `tf.distribute.Strategy` object as
the first argument.
args: List or tuple with positional per-thread arguments for `merge_fn`.
kwargs: Dict with keyword per-thread arguments for `merge_fn`.
Returns:
The return value of `merge_fn`, except for `PerReplica` values which are
unpacked.
"""
require_replica_context(self)
if kwargs is None:
kwargs = {}
merge_fn = autograph.tf_convert(
merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
return self._merge_call(merge_fn, args, kwargs)
def _merge_call(self, merge_fn, args, kwargs):
"""Default implementation for single replica."""
_push_per_thread_mode( # thread-local, so not needed with multiple threads
distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access
try:
return merge_fn(self._strategy, *args, **kwargs)
finally:
_pop_per_thread_mode()
@property
def num_replicas_in_sync(self):
"""Returns number of replicas that are kept in sync."""
return self._strategy.num_replicas_in_sync
@property
def replica_id_in_sync_group(self):
"""Returns the id of the replica.
This identifies the replica among all replicas that are kept in sync. The
value of the replica id can range from 0 to
`tf.distribute.ReplicaContext.num_replicas_in_sync` - 1.
NOTE: This is not guaranteed to be the same ID as the XLA replica ID used
for low-level operations such as collective_permute.
Returns:
a `Tensor`.
"""
# It's important to prefer making the Tensor at call time whenever possible.
# Keeping Tensors in global states doesn't work well with nested
# tf.function, since it's possible that the tensor is generated in one func
# graph, and gets captured by another, which will result in a subtle "An op
# outside of the function building code is being passed a Graph tensor"
# error. We make the tensor at call time to ensure it is in the same graph
# where it's used. However, to be compatible with tpu.replicate(),
# self._replica_id_in_sync_group can also be a Tensor.
if tensor_util.is_tf_type(self._replica_id_in_sync_group):
return self._replica_id_in_sync_group
return constant_op.constant(
self._replica_id_in_sync_group,
dtypes.int32,
name="replica_id_in_sync_group")
@property
def _replica_id(self):
"""This is the local replica id in a given sync group."""
return self._local_replica_id
@property
def strategy(self):
"""The current `tf.distribute.Strategy` object."""
return self._strategy
@property
@deprecation.deprecated(None, "Please avoid relying on devices property.")
def devices(self):
"""Returns the devices this replica is to be executed on, as a tuple of strings.
NOTE: For `tf.distribute.MirroredStrategy` and
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a
nested list of device strings, e.g., [["GPU:0"]].
"""
require_replica_context(self)
return (device_util.current(),)
def all_reduce(self, reduce_op, value, options=None):
"""All-reduces `value` across all replicas.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> def step_fn():
... ctx = tf.distribute.get_replica_context()
... value = tf.identity(1.)
... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
>>> strategy.experimental_local_results(strategy.run(step_fn))
(<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
It supports batched operations. You can pass a list of values and it
attempts to batch them when possible. You can also specify `options`
to indicate the desired batching behavior, e.g. batch the values into
multiple packs so that they can better overlap with computations.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> def step_fn():
... ctx = tf.distribute.get_replica_context()
... value1 = tf.identity(1.)
... value2 = tf.identity(2.)
... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
>>> strategy.experimental_local_results(strategy.run(step_fn))
([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>],
[<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>])
Note that all replicas need to participate in the all-reduce, otherwise this
operation hangs. If there are multiple all-reduces, they need to execute in
the same order on all replicas. Dispatching all-reduce based on conditions
is usually error-prone.
This API currently can only be called in the replica context. Other
variants to reduce values across replicas are:
* `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
in the cross-replica context.
* `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
all-reduce API in the cross-replica context.
* `tf.distribute.Strategy.reduce`: a more convenient method to reduce
to the host in cross-replica context.
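A framework-free sketch of the all-reduce semantics described above, assuming each replica's `value` has already been flattened into a list of Python floats. The hypothetical `toy_all_reduce` is not a TensorFlow API. It reproduces the numbers from the batched example: two replicas each contributing `[1., 2.]` both receive `[2., 4.]`.
```python
def toy_all_reduce(reduce_op, per_replica_values):
  # Hypothetical model: per_replica_values[r] is the flat list of scalars
  # that replica r passes in; all replicas must pass the same structure.
  num_replicas = len(per_replica_values)
  if len({len(vals) for vals in per_replica_values}) != 1:
    raise ValueError("All replicas must all-reduce the same structure.")
  reduced = [sum(column) for column in zip(*per_replica_values)]
  if reduce_op.upper() == "MEAN":
    reduced = [x / num_replicas for x in reduced]
  elif reduce_op.upper() != "SUM":
    raise ValueError("Only SUM and MEAN are modeled here.")
  # All-reduce: every replica receives its own copy of the same result.
  return [list(reduced) for _ in range(num_replicas)]
```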
Args:
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representation of the enum such as
"SUM", "MEAN".
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts.
The structure and the shapes of the `tf.Tensor` need to be same on all
replicas.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.
Returns:
A nested structure of `tf.Tensor` with the reduced values. The structure
is the same as `value`.
"""
if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
if options is None:
options = collective_util.Options()
def batch_all_reduce(strategy, *value_flat):
return strategy.extended.batch_reduce_to(
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
options)
if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
# TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
@custom_gradient.custom_gradient
def grad_wrapper(*xs):
ys = self.merge_call(batch_all_reduce, args=xs)
# The gradient of an all-sum is itself an all-sum (all-mean, likewise).
return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
else:
# TODO(cjfj): Implement gradients for other reductions.
reduced = nest.pack_sequence_as(
value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
return nest.map_structure(array_ops.prevent_gradient, reduced)
# TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
# all-reduce. It would return a function returning the result of reducing `t`
# across all replicas. The caller would wait to call this function until they
# needed the reduce result, allowing an efficient implementation:
# * With eager execution, the reduction could be performed asynchronously
# in the background, not blocking until the result was needed.
# * When constructing a graph, it could batch up all reduction requests up
# to the point where the first result is needed. Most likely this can be
# implemented in terms of `merge_call()` and `batch_reduce_to()`.
@tf_export("distribute.ReplicaContext", v1=[])
class ReplicaContext(ReplicaContextBase):
__doc__ = ReplicaContextBase.__doc__
def all_gather(self, value, axis, options=None):
"""All-gathers `value` across all replicas along `axis`.
Note: `all_gather` can only be called in replica context. For
a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
All replicas need to participate in the all-gather, otherwise this
operation hangs. So if `all_gather` is called in any replica, it must be
called in all replicas.
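The participation requirement can be pictured as a rendezvous: no replica can proceed until every replica has contributed its value. This is a loose analogy only, not how TensorFlow implements collectives. The pure-Python sketch below models it with `threading.Barrier`. `ToyAllGather` and `run_replicas` are hypothetical names, and threads stand in for replicas. A timeout is added so that the missing-replica case raises an error instead of hanging forever.

```python
import threading


class ToyAllGather:
  # Hypothetical, barrier-based stand-in for a collective all-gather.
  # Analogy only; this is not TensorFlow's implementation.

  def __init__(self, num_replicas, timeout=0.5):
    self._barrier = threading.Barrier(num_replicas, timeout=timeout)
    self._slots = [None] * num_replicas

  def all_gather(self, replica_id, value):
    self._slots[replica_id] = value
    # Block until every replica has contributed. If one never arrives,
    # wait() times out with BrokenBarrierError (a real collective hangs).
    self._barrier.wait()
    gathered = list(self._slots)
    # Second rendezvous: all replicas read the result before any replica
    # could start a later all_gather and overwrite its slot.
    self._barrier.wait()
    return gathered


def run_replicas(num_replicas, participating):
  op = ToyAllGather(num_replicas)
  results = {}

  def replica_fn(replica_id):
    if replica_id not in participating:
      results[replica_id] = "skipped"  # This replica never calls all_gather.
      return
    try:
      results[replica_id] = op.all_gather(replica_id, replica_id * 10)
    except threading.BrokenBarrierError:
      results[replica_id] = "stuck"

  threads = [threading.Thread(target=replica_fn, args=(i,))
             for i in range(num_replicas)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()
  return [results[i] for i in range(num_replicas)]


print(run_replicas(2, participating={0, 1}))  # [[0, 10], [0, 10]]
print(run_replicas(2, participating={0}))     # ['stuck', 'skipped']
```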
Note: If there are multiple `all_gather` calls, they need to be executed in
the same order on all replicas. Dispatching `all_gather` based on conditions
is usually error-prone.
For all strategies except `tf.distribute.TPUStrategy`, the input
`value` on different replicas must have the same rank, and their shapes must
be the same in all dimensions except the `axis`-th dimension. In other
words, their shapes cannot differ in any dimension `d` where `d` is not equal
to the `axis` argument. For example, given a
`tf.distribute.DistributedValues` with component tensors of shape
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
`all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
or `all_gather(..., axis=2, ...)`. However, with
`tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
same shape.
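The shape rule above can be written as a small predicate over per-replica shapes. The sketch below is a hypothetical, pure-Python helper (`can_all_gather` is not a TensorFlow API). It takes shapes as tuples of ints and assumes a non-negative `axis`. It only restates the constraints in this docstring.

```python
def can_all_gather(shapes, axis, tpu=False):
  # Hypothetical helper: checks the shape rule described above for a list
  # of per-replica shapes (tuples of ints). Assumes a non-negative `axis`.
  if len({len(s) for s in shapes}) != 1:
    return False  # All replicas must have the same rank.
  rank = len(shapes[0])
  if rank == 0 or not 0 <= axis < rank:
    return False  # Rank-0 inputs and out-of-range axes are rejected.
  if tpu:
    # TPUStrategy: every replica must have exactly the same shape.
    return len({tuple(s) for s in shapes}) == 1
  # Other strategies: shapes may differ only along `axis`.
  return all(len({s[d] for s in shapes}) == 1
             for d in range(rank) if d != axis)


print(can_all_gather([(1, 2, 3), (1, 3, 3)], axis=1))            # True
print(can_all_gather([(1, 2, 3), (1, 3, 3)], axis=0))            # False
print(can_all_gather([(1, 2, 3), (1, 3, 3)], axis=1, tpu=True))  # False
```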
Note: The input `value` must have a non-zero rank. Otherwise, consider using
`tf.expand_dims` before gathering it.
You can pass in a single tensor to all-gather:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> @tf.function
... def gather_value():
... ctx = tf.distribute.get_replica_context()
... local_value = tf.constant([1, 2, 3])
... return ctx.all_gather(local_value, axis=0)
>>> result = strategy.run(gather_value)
>>> result
PerReplica:{
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}
>>> strategy.experimental_local_results(result)
(<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>,
<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>)
You can also pass in a nested structure of tensors to all-gather, say, a
list:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> @tf.function
... def gather_nest():
... ctx = tf.distribute.get_replica_context()
... value_1 = tf.constant([1, 2, 3])
... value_2 = tf.constant([[1, 2], [3, 4]])
... # all_gather a nest of `tf.distribute.DistributedValues`
... return ctx.all_gather([value_1, value_2], axis=0)
>>> result = strategy.run(gather_nest)
>>> result
[PerReplica:{
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}, PerReplica:{
0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>,
1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>
}]
>>> strategy.experimental_local_results(result)
([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>],
[<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>])
What if you are all-gathering tensors with different shapes on different
replicas? Consider the following example with two replicas, where you have
`value` as a nested structure consisting of two items to all-gather, `a` and
`b`.
On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`.
On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`.
The result of `all_gather` with `axis=0` (on each of the replicas) is:
```{'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}```
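The semantics of this example can be reproduced in plain Python. Lists stand in for tensors, and each leaf is concatenated along axis 0 in replica order. This is an illustrative sketch only. `toy_all_gather_axis0` is a hypothetical name, and the sketch handles only flat dicts of lists, not arbitrary `tf.nest` structures.

```python
def toy_all_gather_axis0(per_replica_values):
  # Hypothetical helper: per_replica_values[i] is replica i's dict of
  # nested Python lists. Each leaf is concatenated along axis 0 in replica
  # order; in the real op every replica receives this same result.
  keys = per_replica_values[0].keys()
  return {key: [row for replica in per_replica_values for row in replica[key]]
          for key in keys}


replicas = [{'a': [0], 'b': [[0, 1]]},
            {'a': [1], 'b': [[2, 3], [4, 5]]}]
print(toy_all_gather_axis0(replicas))
# {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}
```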
Args:
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
or a `tf.distribute.DistributedValues` instance. The structure of the
`tf.Tensor`s must be the same on all replicas. The underlying tensor
constructs can only be dense tensors with non-zero rank, NOT
`tf.IndexedSlices`.
axis: 0-D int32 Tensor. Dimension along which to gather.
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor. See
`tf.distribute.experimental.CommunicationOptions` for details of the
options.
Returns:
A nested structure of `tf.Tensor` with the gathered values. The structure
is the same as `value`.
"""
for v in nest.flatten(value):
if isinstance(v, ops.IndexedSlices):
raise NotImplementedError("all_gather does not support IndexedSlices")
if options is None:
options = collective_util.Options()
def batch_all_gather(strategy, *value_flat):
return strategy.extended._batch_gather_to( # pylint: disable=protected-access
[(v, _batch_reduce_destination(v)) for v in value_flat], axis,
options)
@custom_gradient.custom_gradient
def grad_wrapper(*xs):
ys = self.merge_call(batch_all_gather, args=xs)
def grad(*dy_s):
grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s)
new_grads = []
for i, grad in enumerate(grads):
input_shape = array_ops.shape(xs[i])
axis_dim = array_ops.reshape(input_shape[axis], [1])
with ops.control_dependencies([array_ops.identity(grads)]):
d = self.all_gather(axis_dim, axis=0)
begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group])
end_dim = begin_dim + array_ops.shape(xs[i])[axis]
new_grad = array_ops.gather(
grad, axis=axis, indices=math_ops.range(begin_dim, end_dim))
new_grads.append(new_grad)
return new_grads
return ys, grad
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
def _update(self, var, fn, args=(), kwargs=None, group=True):
"""Run `fn` to update `var` with `args` and `kwargs` in replica context.
`ReplicaContext._update` takes a (distributed) variable `var`
to be updated, an update function `fn`, and `args` and `kwargs` for `fn`.
`fn` applies to each component variable of `var` with corresponding input
values from `args` and `kwargs`.
Example usage:
>>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'CPU:0']) # 2 replicas
>>> with strategy.scope():
... distributed_variable = tf.Variable(5.0)
>>> distributed_variable
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0>
}
>>> def replica_fn(v):
... value = tf.identity(1.0)
... replica_context = tf.distribute.get_replica_context()
... update_fn = lambda var, value: var.assign(value)
... replica_context._update(v, update_fn, args=(value,))
>>> strategy.run(replica_fn, args=(distributed_variable,))
>>> distributed_variable
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
This API must be called in a replica context.
Note that if `var` is a MirroredVariable (i.e., the type of variable created
under the scope of a synchronous strategy, which is synchronized on write; see
`tf.VariableSynchronization` for more information) and `args`/`kwargs`
contain different values for different replicas, the copies of `var` will be
dangerously out of sync. We therefore recommend using `variable.assign(value)`
whenever possible, since it aggregates the updates under the hood and
guarantees synchronization. Use this API instead of `variable.assign(value)`
when, before assigning `value` to the variable, you want to run some
pre-`assign` computation colocated with the variable's devices (i.e., where
the variables reside: for MirroredStrategy these are the compute devices, and
for ParameterServerStrategy they are the parameter servers). For example:
```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
with strategy.scope():
  v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)

def replica_fn(inputs):
  # `computation`, `post_reduce_pre_update_computation` and `inputs` are
  # placeholders for user-defined functions and data.
  value = computation(inputs)
  replica_context = tf.distribute.get_replica_context()
  reduced_value = replica_context.all_reduce(tf.distribute.ReduceOp.SUM, value)

  def update_fn(var, value):
    # This computation is colocated with `var`'s device.
    updated_value = post_reduce_pre_update_computation(value)
    var.assign(updated_value)

  replica_context._update(v, update_fn, args=(reduced_value,))

strategy.run(replica_fn, args=(inputs,))
```
This snippet behaves consistently across all strategies. If you instead
compute and call `assign` directly in replica context rather than wrapping
them in `_update`, then for strategies with fewer variable devices than
compute devices (usually the case for parameter server strategy),
`post_reduce_pre_update_computation` runs N times, where N is the number of
compute devices, which is less efficient.
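The synchronization hazard described in the note above can be illustrated with a toy model. In this model, a mirrored variable is just a list of per-device copies. This is a pure-Python sketch, not TensorFlow behavior verbatim. `ToyMirroredVariable`, `update_per_replica` and `assign_aggregated` are hypothetical names. The aggregated path assumes `tf.VariableAggregation.SUM` semantics.

```python
class ToyMirroredVariable:
  # Hypothetical stand-in: one Python float per device copy.

  def __init__(self, value, num_replicas):
    self.components = [float(value)] * num_replicas

  def update_per_replica(self, fn, per_replica_args):
    # Like calling `_update` with replica-specific args: each copy is
    # updated with its own replica's value, so the copies can diverge.
    for i, arg in enumerate(per_replica_args):
      self.components[i] = fn(self.components[i], arg)

  def assign_aggregated(self, per_replica_values):
    # Like `assign` with SUM aggregation: the values are combined first,
    # then the same result is written to every copy.
    total = sum(per_replica_values)
    self.components = [total] * len(self.components)

  def in_sync(self):
    return len(set(self.components)) == 1


v = ToyMirroredVariable(5.0, num_replicas=2)
v.update_per_replica(lambda old, new: new, [1.0, 2.0])
print(v.components, v.in_sync())  # [1.0, 2.0] False
v.assign_aggregated([1.0, 2.0])
print(v.components, v.in_sync())  # [3.0, 3.0] True
```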
Args:
var: Variable, possibly distributed to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
args: Tuple or list. Additional positional arguments to pass to `fn()`.
kwargs: Dict with keyword arguments to pass to `fn()`.
group: Boolean. Defaults to True. Most strategies enter a merge_call to
conduct the update in cross-replica context, and group=True guarantees that
the updates on all replicas are executed.
Returns:
The return value of `fn` for the local replica.
"""
if kwargs is None:
kwargs = {}
return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group) # pylint: disable=protected-access
@tf_export(v1=["distribute.ReplicaContext"])
class ReplicaContextV1(ReplicaContextBase):
__doc__ = ReplicaContextBase.__doc__
def _batch_reduce_destination(x):
"""Returns the destinations for batch all-reduce."""
if isinstance(x, ops.Tensor):
# If this is a one device strategy.
return x.device
else:
return x
# ------------------------------------------------------------------------------
_creating_default_strategy_singleton = False
class _DefaultDistributionStrategyV1(StrategyV1):
"""Default `tf.distribute.Strategy` if none is explicitly selected."""
def __init__(self):
if not _creating_default_strategy_singleton:
raise RuntimeError("Should only create a single instance of "
"_DefaultDistributionStrategy")
super(_DefaultDistributionStrategyV1,
self).__init__(_DefaultDistributionExtended(self))
def __deepcopy__(self, memo):
del memo
raise RuntimeError("Should only create a single instance of "
"_DefaultDistributionStrategy")
class _DefaultDistributionStrategy(Strategy):
"""Default `tf.distribute.Strategy` if none is explicitly selected."""
def __init__(self):
if not _creating_default_strategy_singleton:
raise RuntimeError("Should only create a single instance of "
"_DefaultDistributionStrategy")
super(_DefaultDistributionStrategy, self).__init__(
_DefaultDistributionExtended(self))
def __deepcopy__(self, memo):
del memo
raise RuntimeError("Should only create a single instance of "
"_DefaultDistributionStrategy")
class _DefaultDistributionContext(object):
"""Context manager setting the default `tf.distribute.Strategy`."""
__slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
def __init__(self, strategy):
def creator(next_creator, **kwargs):
_require_strategy_scope_strategy(strategy)
return next_creator(**kwargs)
self._var_creator_scope = variable_scope.variable_creator_scope(creator)
self._strategy = strategy
self._nested_count = 0
def __enter__(self):
# Raise if a non-default strategy is already in scope. Re-entering this
# default-strategy scope is allowed and tracked by `_nested_count`.
if distribution_strategy_context.has_strategy():
raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
if self._nested_count == 0:
self._var_creator_scope.__enter__()
self._nested_count += 1
return self._strategy
def __exit__(self, exception_type, exception_value, traceback):
self._nested_count -= 1
if self._nested_count == 0:
try:
self._var_creator_scope.__exit__(
exception_type, exception_value, traceback)
except RuntimeError as e:
six.raise_from(
RuntimeError("Variable creator scope nesting error: move call to "
"tf.distribute.set_strategy() out of `with` scope."),
e)
class _DefaultDistributionExtended(StrategyExtendedV1):
"""Implementation of _DefaultDistributionStrategy."""
def __init__(self, container_strategy):
super(_DefaultDistributionExtended, self).__init__(container_strategy)
self._retrace_functions_for_each_device = False
def _scope(self, strategy):
"""Context manager setting a variable creator and `self` as current."""
return _DefaultDistributionContext(strategy)
def colocate_vars_with(self, colocate_with_variable):
"""Does not require `self.scope`."""
_require_strategy_scope_extended(self)
return ops.colocate_with(colocate_with_variable)
def variable_created_in_scope(self, v):
return v._distribute_strategy is None # pylint: disable=protected-access
def _experimental_distribute_dataset(self, dataset, options):
return dataset
def _distribute_datasets_from_function(self, dataset_fn, options):
return dataset_fn(InputContext())
def _experimental_distribute_values_from_function(self, value_fn):
return value_fn(ValueContext())
def _make_dataset_iterator(self, dataset):
return _DefaultDistributionExtended.DefaultInputIterator(dataset)
def _make_input_fn_iterator(self,
input_fn,
replication_mode=InputReplicationMode.PER_WORKER):
dataset = input_fn(InputContext())
return _DefaultDistributionExtended.DefaultInputIterator(dataset)
def _experimental_make_numpy_dataset(self, numpy_input, session):
numpy_flat = nest.flatten(numpy_input)
vars_flat = tuple(
variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
trainable=False, use_resource=True)
for i in numpy_flat
)
for v, i in zip(vars_flat, numpy_flat):
numpy_dataset.init_var_from_numpy(v, i, session)
vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
return dataset_ops.Dataset.from_tensor_slices(vars_nested)
def _broadcast_to(self, tensor, destinations):
if destinations is None:
return tensor
else:
raise NotImplementedError("TODO")
def _call_for_each_replica(self, fn, args, kwargs):
with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
return fn(*args, **kwargs)
def _reduce_to(self, reduce_op, value, destinations, options):
# TODO(josh11b): Use destinations?
del reduce_op, destinations, options
return value
def _gather_to_implementation(self, value, destinations, axis, options):
del destinations, axis, options
return value
def _update(self, var, fn, args, kwargs, group):
# The implementations of _update() and _update_non_slot() are identical
# except _update() passes `var` as the first argument to `fn()`.
return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
# TODO(josh11b): Figure out what we should be passing to UpdateContext()
# once that value is used for something.
with UpdateContext(colocate_with):
result = fn(*args, **kwargs)
if should_group:
return result
else:
return nest.map_structure(self._local_results, result)
def read_var(self, replica_local_var):
return array_ops.identity(replica_local_var)
def _local_results(self, distributed_value):
return (distributed_value,)
def value_container(self, value):
return value
@property
def _num_replicas_in_sync(self):
return 1
@property
def worker_devices(self):
raise RuntimeError("worker_devices() method unsupported by default "
"tf.distribute.Strategy.")
@property
def parameter_devices(self):
raise RuntimeError("parameter_devices() method unsupported by default "
"tf.distribute.Strategy.")
def non_slot_devices(self, var_list):
return min(var_list, key=lambda x: x.name)
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
# Default strategy doesn't indicate multi-worker training.
return False
@property
def should_checkpoint(self):
return True
@property
def should_save_summary(self):
return True
def _get_local_replica_id(self, replica_id_in_sync_group):
return replica_id_in_sync_group
def _get_replica_id_in_sync_group(self, replica_id):
return replica_id
# TODO(priyag): This should inherit from `InputIterator`, once dependency
# issues have been resolved.
class DefaultInputIterator(object):
"""Default implementation of `InputIterator` for default strategy."""
def __init__(self, dataset):
self._dataset = dataset
if eager_context.executing_eagerly():
self._iterator = dataset_ops.make_one_shot_iterator(dataset)
else:
self._iterator = dataset_ops.make_initializable_iterator(dataset)
def get_next(self):
return self._iterator.get_next()
def get_next_as_optional(self):
return self._iterator.get_next_as_optional()
@deprecated(None, "Use the iterator's `initializer` property instead.")
def initialize(self):
"""Initialize underlying iterators.
Returns:
A list of any initializer ops that should be run.
"""
if eager_context.executing_eagerly():
self._iterator = dataset_ops.make_one_shot_iterator(self._dataset)
return []
else:
return [self._iterator.initializer]
@property
def initializer(self):
"""Returns a list of ops that initialize the iterator."""
return self.initialize()
# TODO(priyag): Delete this once all strategies use global batch size.
@property
def _global_batch_size(self):
"""Global and per-replica batching are equivalent for this strategy."""
return True
class _DefaultReplicaContext(ReplicaContext):
"""ReplicaContext for _DefaultDistributionStrategy."""
@property
def replica_id_in_sync_group(self):
# Return 0 instead of a constant tensor to avoid creating a new node for
# users who don't use distribution strategy.
return 0
# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
# pylint: disable=protected-access
_original_from_proto = resource_variable_ops._from_proto_fn
def _from_proto_fn(v, import_scope=None):
if distribution_strategy_context.has_strategy():
raise NotImplementedError(
"Deserialization of variables is not yet supported when using a "
"tf.distribute.Strategy.")
else:
return _original_from_proto(v, import_scope=import_scope)
resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
#-------------------------------------------------------------------------------
# Shorthand for some methods from distribution_strategy_context.
_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access
_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access
_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access
_get_default_replica_mode = (
distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access
# ------------------------------------------------------------------------------
# Metrics to track which distribution strategy is being called
distribution_strategy_gauge = monitoring.StringGauge(
"/tensorflow/api/distribution_strategy",
"Gauge to track the type of distribution strategy used.", "TFVersion")
distribution_strategy_replica_gauge = monitoring.IntGauge(
"/tensorflow/api/distribution_strategy/replica",
"Gauge to track the number of replicas each distribution strategy used.",
"CountType")