| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=line-too-long |
| """Library for running a computation across multiple devices. |
| |
| The intent of this library is that you can write an algorithm in a stylized way |
| and it will be usable with a variety of different `tf.distribute.Strategy` |
| implementations. Each descendant will implement a different strategy for |
| distributing the algorithm across multiple devices/machines. Furthermore, these |
| changes can be hidden inside the specific layers and other library classes that |
| need special treatment to run in a distributed setting, so that most users' |
| model definition code can run unchanged. The `tf.distribute.Strategy` API works |
| the same way with eager and graph execution. |
| |
| *Guides* |
| |
| * [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training) |
| * [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb) |
| |
| *Tutorials* |
| |
| * [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/) |
| |
The tutorials cover how to use `tf.distribute.Strategy` to do distributed
training with native Keras APIs, custom training loops,
and Estimator APIs. They also cover how to save/load a model when using
`tf.distribute.Strategy`.
| |
| *Glossary* |
| |
| * _Data parallelism_ is where we run multiple copies of the model |
| on different slices of the input data. This is in contrast to |
| _model parallelism_ where we divide up a single copy of a model |
| across multiple devices. |
| Note: we only support data parallelism for now, but |
| hope to add support for model parallelism in the future. |
| * A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that |
| TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple |
| devices on a single machine, or be connected to devices on multiple |
| machines. Devices used to run computations are called _worker devices_. |
| Devices used to store variables are _parameter devices_. For some strategies, |
| such as `tf.distribute.MirroredStrategy`, the worker and parameter devices |
| will be the same (see mirrored variables below). For others they will be |
| different. For example, `tf.distribute.experimental.CentralStorageStrategy` |
| puts the variables on a single device (which may be a worker device or may be |
| the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the |
| variables on separate machines called _parameter servers_ (see below). |
| * A _replica_ is one copy of the model, running on one slice of the |
| input data. Right now each replica is executed on its own |
| worker device, but once we add support for model parallelism |
| a replica may span multiple worker devices. |
| * A _host_ is the CPU device on a machine with worker devices, typically |
| used for running input pipelines. |
* A _worker_ is defined to be the physical machine(s) containing the physical
devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
worker contains one or more replicas. Typically one worker will correspond to
one machine, but in the case of very large models with model parallelism, one
worker may span multiple machines. We typically run one input pipeline per
worker, feeding all the replicas on that worker.
| * _Synchronous_, or more commonly _sync_, training is where the updates from |
| each replica are aggregated together before updating the model variables. This |
| is in contrast to _asynchronous_, or _async_ training, where each replica |
| updates the model variables independently. You may also have replicas |
| partitioned into groups which are in sync within each group but async between |
| groups. |
| * _Parameter servers_: These are machines that hold a single copy of |
| parameters/variables, used by some strategies (right now just |
| `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want |
| to operate on a variable retrieve it at the beginning of a step and send an |
| update to be applied at the end of the step. These can in principle support |
| either sync or async training, but right now we only have support for async |
| training with parameter servers. Compare to |
| `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables |
| on a single device on the same machine (and does sync training), and |
| `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices |
| (see below). |
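The difference between sync and async updates can be sketched in plain Python
(illustrative names only, not the tf.distribute implementation):

```python
# Minimal sketch contrasting synchronous and asynchronous training updates.
# `sync_step` and `async_step` are hypothetical names for illustration.

def sync_step(variable, per_replica_grads, lr=1.0):
    # Synchronous: aggregate (here, average) the gradients from all
    # replicas first, then apply a single update to the variable.
    mean_grad = sum(per_replica_grads) / len(per_replica_grads)
    return variable - lr * mean_grad

def async_step(variable, per_replica_grads, lr=1.0):
    # Asynchronous: each replica applies its own gradient independently,
    # whenever it arrives, without waiting for the others.
    for g in per_replica_grads:
        variable = variable - lr * g
    return variable

print(sync_step(10.0, [2.0, 4.0]))   # 10 - mean(2, 4) = 7.0
print(async_step(10.0, [2.0, 4.0]))  # 10 - 2 - 4 = 4.0
```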
| |
| * _Replica context_ vs. _Cross-replica context_ vs _Update context_ |
| |
A _replica context_ applies
when you execute the computation function that was passed to `strategy.run`.
Conceptually, you're in replica context while executing the computation
function that is being replicated.
| |
| An _update context_ is entered in a `tf.distribute.StrategyExtended.update` |
| call. |
| |
A _cross-replica context_ is entered when you enter a `strategy.scope`. This
is useful for calling `tf.distribute.Strategy` methods which operate across
the replicas (like `reduce_to()`). By default you start in a _replica context_
(the "default single _replica context_"), and some methods can switch you
back and forth.
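The switching between contexts described above is tracked per thread. A
minimal plain-Python sketch of the idea, using hypothetical names rather than
the actual tf.distribute internals:

```python
# Sketch of per-thread "current context" tracking via a context manager.
# `ReplicaContextSketch` and `get_replica_context_sketch` are illustrative
# names, not part of the tf.distribute API.

import threading

_current = threading.local()

class ReplicaContextSketch:
    def __init__(self, replica_id):
        self.replica_id = replica_id

    def __enter__(self):
        # Entering the context makes it the current one for this thread.
        _current.ctx = self
        return self

    def __exit__(self, *exc):
        # Leaving restores "no replica context" for this thread.
        _current.ctx = None

def get_replica_context_sketch():
    return getattr(_current, "ctx", None)

with ReplicaContextSketch(replica_id=0) as ctx:
    assert get_replica_context_sketch() is ctx  # in replica context
assert get_replica_context_sketch() is None     # back outside
```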
| |
* _Distributed value_: A distributed value is represented by the base class
`tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
for representing values on multiple devices, and it contains a map from
replica id to values. Two representative kinds of
`tf.distribute.DistributedValues` are "PerReplica" and "Mirrored" values.
| |
| "PerReplica" values exist on the worker |
| devices, with a different value for each replica. They are produced by |
| iterating through a distributed dataset returned by |
`tf.distribute.Strategy.experimental_distribute_dataset` or
| `tf.distribute.Strategy.distribute_datasets_from_function`. They |
| are also the typical result returned by |
| `tf.distribute.Strategy.run`. |
| |
| "Mirrored" values are like "PerReplica" values, except we know that the value |
on all replicas is the same. We can safely read a "Mirrored" value in a
| cross-replica context by using the value on any replica. |
| |
| * _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple |
| replicas, like `strategy.run(fn, args=[w])` with an |
| argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will |
| have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc. |
| `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on |
| device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return |
| values from `fn()`, which leads to one common object if the returned values |
| are the same object from every replica, or a `DistributedValues` object |
| otherwise. |
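The unwrap/merge behavior can be sketched in plain Python, modeling a
`DistributedValues` as a simple map from replica id to value (an illustrative
simplification, not the real class):

```python
# Sketch of how `strategy.run` unwraps a distributed argument and merges
# the per-replica results. `run` here is a hypothetical stand-in.

def run(fn, distributed_arg):
    # Unwrap: call fn once per replica, with that replica's component.
    per_replica_results = {rid: fn(v) for rid, v in distributed_arg.items()}
    # Merge: collapse to one common object if every replica returned the
    # same value; otherwise keep the per-replica mapping. (The real
    # implementation compares returned objects, not just values.)
    unique = set(per_replica_results.values())
    if len(unique) == 1:
        return unique.pop()
    return per_replica_results

w = {0: 1, 1: 2}                 # replica 0 sees 1, replica 1 sees 2
print(run(lambda x: x * 10, w))  # differing results: {0: 10, 1: 20}
print(run(lambda x: 7, w))       # identical results merge to: 7
```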
| |
| * _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating |
| multiple values into one value, like "sum" or "mean". If a strategy is doing |
| sync training, we will perform a reduction on the gradients to a parameter |
| from all replicas before applying the update. _All-reduce_ is an algorithm for |
| performing a reduction on values from multiple devices and making the result |
| available on all of those devices. |
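As a plain-Python sketch of the difference between a reduction and an
all-reduce (real strategies use collective ops, e.g. NCCL, under the hood;
these function names are illustrative):

```python
# A reduction collapses many per-device values into one; an all-reduce
# additionally makes that reduced result available on every device.

def reduce_sum(per_device_values):
    # Reduction: many values -> one value.
    return sum(per_device_values)

def all_reduce_sum(per_device_values):
    # All-reduce: every device ends up holding the reduced result.
    total = reduce_sum(per_device_values)
    return [total] * len(per_device_values)

grads = [1.0, 2.0, 3.0]       # one gradient per device
print(reduce_sum(grads))      # 6.0
print(all_reduce_sum(grads))  # [6.0, 6.0, 6.0]
```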
| |
| * _Mirrored variables_: These are variables that are created on multiple |
| devices, where we keep the variables in sync by applying the same |
| updates to every copy. Mirrored variables are created with |
| `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`. |
| Normally they are only used in synchronous training. |
| |
| * _SyncOnRead variables_ |
| |
| _SyncOnRead variables_ are created by |
| `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and |
| they are created on multiple devices. In replica context, each |
| component variable on the local replica can perform reads and writes without |
| synchronization with each other. When the |
| _SyncOnRead variable_ is read in cross-replica context, the values from |
| component variables are aggregated and returned. |
| |
_SyncOnRead variables_ add considerable configuration complexity to the
underlying logic, so we do not encourage users to instantiate and use
_SyncOnRead variables_ on their own. We have mainly used _SyncOnRead
variables_ for use cases such as batch norm and metrics. For performance
reasons, we often don't need to keep these statistics in sync every step and
they can be accumulated on each replica independently. The only time we want
to sync them is when reporting or checkpointing, which typically happens in
| cross-replica context. _SyncOnRead variables_ are also often used by advanced |
| users who want to control when variable values are aggregated. For example, |
| users sometimes want to maintain gradients independently on each replica for a |
| couple of steps without aggregation. |
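A plain-Python sketch of the SyncOnRead read/write semantics described above
(illustrative only; `SyncOnReadSketch` is not a tf.distribute class):

```python
# Each replica writes to its own component without synchronization; a
# cross-replica read aggregates the components (e.g. for reporting a
# metric or checkpointing).

class SyncOnReadSketch:
    def __init__(self, num_replicas, aggregate=sum):
        self.components = [0] * num_replicas  # one component per replica
        self.aggregate = aggregate

    def assign_add(self, replica_id, value):
        # Replica-context write: touches only the local component.
        self.components[replica_id] += value

    def read(self):
        # Cross-replica read: aggregate across all components.
        return self.aggregate(self.components)

counter = SyncOnReadSketch(num_replicas=2)
counter.assign_add(0, 3)   # replica 0 updates locally
counter.assign_add(1, 4)   # replica 1 updates locally
print(counter.read())      # aggregated on read: 7
```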
| |
| * _Distribute-aware layers_ |
| |
| Layers are generally called in a replica context, except when defining a |
| Keras functional model. `tf.distribute.in_cross_replica_context` will let you |
| determine which case you are in. If in a replica context, |
| the `tf.distribute.get_replica_context` function will return the default |
| replica context outside a strategy scope, `None` within a strategy scope, and |
| a `tf.distribute.ReplicaContext` object inside a strategy scope and within a |
| `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an |
| `all_reduce` method for aggregating across all replicas. |
| |
| |
Note that we provide a default version of `tf.distribute.Strategy`, used when
no other strategy is in scope, which provides the same API with reasonable
default behavior.
| """ |
| # pylint: enable=line-too-long |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import copy |
| import enum # pylint: disable=g-bad-import-order |
| import functools |
| import threading |
| import weakref |
| |
| import six |
| |
| from tensorflow.python.autograph.core import ag_ctx as autograph_ctx |
| from tensorflow.python.autograph.impl import api as autograph |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import collective_util |
| from tensorflow.python.distribute import device_util |
| from tensorflow.python.distribute import distribution_strategy_context |
| from tensorflow.python.distribute import numpy_dataset |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.distribute import values |
| from tensorflow.python.eager import context as eager_context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import monitoring |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import summary_ops_v2 |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops.losses import losses_impl |
| from tensorflow.python.platform import tf_logging |
| from tensorflow.python.training.tracking import base as trackable |
| from tensorflow.python.util import deprecation |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_contextlib |
| from tensorflow.python.util.deprecation import deprecated |
| from tensorflow.python.util.tf_export import tf_export |
| from tensorflow.tools.docs import doc_controls |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Context tracking whether in a strategy.update() or .update_non_slot() call. |
| |
| |
| _update_replica_id = threading.local() |
| |
| |
| def get_update_replica_id(): |
| """Get the current device if in a `tf.distribute.Strategy.update()` call.""" |
| try: |
| return _update_replica_id.current |
| except AttributeError: |
| return None |
| |
| |
| class UpdateContext(object): |
| """Context manager when you are in `update()` or `update_non_slot()`.""" |
| |
| __slots__ = ["_replica_id", "_old_replica_id"] |
| |
| def __init__(self, replica_id): |
| self._replica_id = replica_id |
| self._old_replica_id = None |
| |
| def __enter__(self): |
| self._old_replica_id = get_update_replica_id() |
| _update_replica_id.current = self._replica_id |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| del exception_type, exception_value, traceback |
| _update_replica_id.current = self._old_replica_id |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Public utility functions. |
| |
| |
| @tf_export(v1=["distribute.get_loss_reduction"]) |
| def get_loss_reduction(): |
| """`tf.distribute.ReduceOp` corresponding to the last loss reduction. |
| |
This is used to decide whether the loss should be scaled in the optimizer
(used only for the estimator + v1 optimizer use case).
| |
| Returns: |
| `tf.distribute.ReduceOp` corresponding to the last loss reduction for |
| estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise. |
| """ |
| if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access |
| # If we are not in Estimator context then return 'SUM'. We do not need to |
| # scale loss in the optimizer. |
| return reduce_util.ReduceOp.SUM |
| last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access |
| if (last_reduction == losses_impl.Reduction.SUM or |
| last_reduction == "sum"): # Check for tf.keras.losses.Reduction.SUM |
| return reduce_util.ReduceOp.SUM |
| return reduce_util.ReduceOp.MEAN |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Internal API for validating the current thread mode |
| |
| |
| def _require_cross_replica_or_default_context_extended(extended, |
| error_message=None): |
| """Verify in cross-replica context.""" |
| context = _get_per_thread_mode() |
| cross_replica = context.cross_replica_context |
| if cross_replica is not None and cross_replica.extended is extended: |
| return |
| if context is _get_default_replica_mode(): |
| return |
| strategy = extended._container_strategy() # pylint: disable=protected-access |
| # We have an error to report, figure out the right message. |
| if context.strategy is not strategy: |
| _wrong_strategy_scope(strategy, context) |
| assert cross_replica is None |
| if not error_message: |
| error_message = ("Method requires being in cross-replica context, use " |
| "get_replica_context().merge_call()") |
| raise RuntimeError(error_message) |
| |
| |
| def _wrong_strategy_scope(strategy, context): |
| # Figure out the right error message. |
| if not distribution_strategy_context.has_strategy(): |
| raise RuntimeError( |
| 'Need to be inside "with strategy.scope()" for %s' % |
| (strategy,)) |
| else: |
| raise RuntimeError( |
| "Mixing different tf.distribute.Strategy objects: %s is not %s" % |
| (context.strategy, strategy)) |
| |
| |
| def require_replica_context(replica_ctx): |
| """Verify in `replica_ctx` replica context.""" |
| context = _get_per_thread_mode() |
| if context.replica_context is replica_ctx: return |
| # We have an error to report, figure out the right message. |
| if context.replica_context is None: |
| raise RuntimeError("Need to be inside `call_for_each_replica()`") |
| if context.strategy is replica_ctx.strategy: |
| # Two different ReplicaContexts with the same tf.distribute.Strategy. |
| raise RuntimeError("Mismatching ReplicaContext.") |
| raise RuntimeError( |
| "Mismatching tf.distribute.Strategy objects: %s is not %s." % |
| (context.strategy, replica_ctx.strategy)) |
| |
| |
| def _require_strategy_scope_strategy(strategy): |
| """Verify in a `strategy.scope()` in this thread.""" |
| context = _get_per_thread_mode() |
| if context.strategy is strategy: return |
| _wrong_strategy_scope(strategy, context) |
| |
| |
| def _require_strategy_scope_extended(extended): |
| """Verify in a `distribution_strategy.scope()` in this thread.""" |
| context = _get_per_thread_mode() |
| if context.strategy.extended is extended: return |
| # Report error. |
| strategy = extended._container_strategy() # pylint: disable=protected-access |
| _wrong_strategy_scope(strategy, context) |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Internal context managers used to implement the DistributionStrategy |
| # base class |
| |
| |
| class _CurrentDistributionContext(object): |
| """Context manager setting the current `tf.distribute.Strategy`. |
| |
| Also: overrides the variable creator and optionally the current device. |
| """ |
| |
| def __init__(self, |
| strategy, |
| var_creator_scope, |
| var_scope=None, |
| default_device=None): |
| self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access |
| strategy) |
| self._var_creator_scope = var_creator_scope |
| self._var_scope = var_scope |
| if default_device: |
| self._device_scope = ops.device(default_device) |
| else: |
| self._device_scope = None |
| self._same_scope_again_count = 0 |
| |
| def __enter__(self): |
| # Allow this scope to be entered if this strategy is already in scope. |
| if distribution_strategy_context.has_strategy(): |
| _require_cross_replica_or_default_context_extended( |
| self._context.strategy.extended) |
| self._same_scope_again_count += 1 |
| else: |
| _push_per_thread_mode(self._context) |
| if self._var_scope: |
| self._var_scope.__enter__() |
| self._var_creator_scope.__enter__() |
| if self._device_scope: |
| self._device_scope.__enter__() |
| return self._context.strategy |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| if self._same_scope_again_count > 0: |
| self._same_scope_again_count -= 1 |
| return |
| if self._device_scope: |
| try: |
| self._device_scope.__exit__(exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Device scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| try: |
| self._var_creator_scope.__exit__( |
| exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable creator scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| if self._var_scope: |
| try: |
| self._var_scope.__exit__(exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| _pop_per_thread_mode() |
| |
| |
| # TODO(yuefengz): add more replication modes. |
| @tf_export("distribute.InputReplicationMode") |
| class InputReplicationMode(enum.Enum): |
| """Replication mode for input function. |
| |
* `PER_WORKER`: The input function will be called on each worker
independently, creating as many input pipelines as there are workers.
Replicas will dequeue from the local Dataset on their worker.
`tf.distribute.Strategy` doesn't manage any state sharing between such
separate input pipelines.
| * `PER_REPLICA`: The input function will be called on each replica separately. |
| `tf.distribute.Strategy` doesn't manage any state sharing between such |
| separate input pipelines. |
| """ |
| PER_WORKER = "PER_WORKER" |
| PER_REPLICA = "PER_REPLICA" |
| |
| |
| @tf_export("distribute.InputContext") |
| class InputContext(object): |
| """A class wrapping information needed by an input function. |
| |
This is a context class that is passed to the user's input function and
contains information about the compute replicas and input pipelines. The
number of compute replicas (in sync training) helps compute the local batch
size from the desired global batch size for each replica. The input pipeline
information can be used to return a different subset of the input in each
replica (e.g. shard the input pipeline, use a different input
source, etc.).
| """ |
| |
| __slots__ = [ |
| "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync" |
| ] |
| |
| def __init__(self, |
| num_input_pipelines=1, |
| input_pipeline_id=0, |
| num_replicas_in_sync=1): |
| """Initializes an InputContext object. |
| |
| Args: |
| num_input_pipelines: the number of input pipelines in a cluster. |
| input_pipeline_id: the current input pipeline id, should be an int in |
| [0,`num_input_pipelines`). |
| num_replicas_in_sync: the number of replicas that are in sync. |
| """ |
| self._num_input_pipelines = num_input_pipelines |
| self._input_pipeline_id = input_pipeline_id |
| self._num_replicas_in_sync = num_replicas_in_sync |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns the number of compute replicas in sync.""" |
| return self._num_replicas_in_sync |
| |
| @property |
| def input_pipeline_id(self): |
| """Returns the input pipeline ID.""" |
| return self._input_pipeline_id |
| |
| @property |
| def num_input_pipelines(self): |
| """Returns the number of input pipelines.""" |
| return self._num_input_pipelines |
| |
| def get_per_replica_batch_size(self, global_batch_size): |
| """Returns the per-replica batch size. |
| |
| Args: |
| global_batch_size: the global batch size which should be divisible by |
| `num_replicas_in_sync`. |
| |
| Returns: |
| the per-replica batch size. |
| |
| Raises: |
| ValueError: if `global_batch_size` not divisible by |
| `num_replicas_in_sync`. |
| """ |
| if global_batch_size % self._num_replicas_in_sync != 0: |
| raise ValueError("The `global_batch_size` %r is not divisible by " |
| "`num_replicas_in_sync` %r " % |
| (global_batch_size, self._num_replicas_in_sync)) |
| return global_batch_size // self._num_replicas_in_sync |
| |
| def __str__(self): |
| return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( |
| self.input_pipeline_id, self.num_input_pipelines) |
| |
| |
| @tf_export("distribute.experimental.ValueContext", v1=[]) |
| class ValueContext(object): |
| """A class wrapping information needed by a distribute function. |
| |
| This is a context class that is passed to the `value_fn` in |
| `strategy.experimental_distribute_values_from_function` and contains |
| information about the compute replicas. The `num_replicas_in_sync` and |
| `replica_id` can be used to customize the value on each replica. |
| |
| Example usage: |
| |
| 1. Directly constructed. |
| |
| >>> def value_fn(context): |
| ... return context.replica_id_in_sync_group/context.num_replicas_in_sync |
| >>> context = tf.distribute.experimental.ValueContext( |
| ... replica_id_in_sync_group=2, num_replicas_in_sync=4) |
| >>> per_replica_value = value_fn(context) |
| >>> per_replica_value |
| 0.5 |
| |
| 2. Passed in by `experimental_distribute_values_from_function`. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def value_fn(value_context): |
| ... return value_context.num_replicas_in_sync |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (2, 2) |
| |
| """ |
| |
| __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"] |
| |
def __init__(self,
replica_id_in_sync_group=0,
num_replicas_in_sync=1):
"""Initializes a ValueContext object.
| |
| Args: |
| replica_id_in_sync_group: the current replica_id, should be an int in |
| [0,`num_replicas_in_sync`). |
| num_replicas_in_sync: the number of replicas that are in sync. |
| """ |
| self._replica_id_in_sync_group = replica_id_in_sync_group |
| self._num_replicas_in_sync = num_replicas_in_sync |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns the number of compute replicas in sync.""" |
| return self._num_replicas_in_sync |
| |
| @property |
| def replica_id_in_sync_group(self): |
| """Returns the replica ID.""" |
| return self._replica_id_in_sync_group |
| |
| def __str__(self): |
return ("tf.distribute.ValueContext(replica id {}, "
" total replicas in sync: {})".format(
self.replica_id_in_sync_group, self.num_replicas_in_sync))
| |
| |
| @tf_export("distribute.RunOptions") |
| class RunOptions( |
| collections.namedtuple("RunOptions", [ |
| "experimental_enable_dynamic_batch_size", |
| "experimental_bucketizing_dynamic_shape", |
| "experimental_xla_options", |
| ])): |
| """Run options for `strategy.run`. |
| |
This can be used to hold some strategy-specific configs.
| |
Attributes:
experimental_enable_dynamic_batch_size: Boolean. Only applies to
TPUStrategy. Defaults to True. If True, TPUStrategy will enable the dynamic
padder to support dynamic batch sizes for the inputs. Otherwise only static
shape inputs are allowed.
experimental_bucketizing_dynamic_shape: Boolean. Only applies to
TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
bucketize inputs passed into `run` if the input shape is
dynamic. This is a performance optimization to reduce XLA recompilation,
which should not have an impact on correctness.
experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to None.
| """ |
| |
| def __new__(cls, |
| experimental_enable_dynamic_batch_size=True, |
| experimental_bucketizing_dynamic_shape=False, |
| experimental_xla_options=None): |
| return super(RunOptions, |
| cls).__new__(cls, experimental_enable_dynamic_batch_size, |
| experimental_bucketizing_dynamic_shape, |
| experimental_xla_options) |
| |
| |
| @tf_export("distribute.InputOptions", v1=[]) |
| class InputOptions( |
| collections.namedtuple("InputOptions", [ |
| "experimental_fetch_to_device", |
| "experimental_replication_mode", |
| "experimental_place_dataset_on_device", |
| "experimental_per_replica_buffer_size", |
| ])): |
| """Run options for `experimental_distribute_dataset(s_from_function)`. |
| |
This can be used to hold some strategy-specific configs.
| |
| ```python |
| # Setup TPUStrategy |
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') |
| tf.config.experimental_connect_to_cluster(resolver) |
| tf.tpu.experimental.initialize_tpu_system(resolver) |
| strategy = tf.distribute.TPUStrategy(resolver) |
| |
| dataset = tf.data.Dataset.range(16) |
| distributed_dataset_on_host = ( |
| strategy.experimental_distribute_dataset( |
| dataset, |
| tf.distribute.InputOptions( |
experimental_replication_mode=
tf.distribute.InputReplicationMode.PER_WORKER,
| experimental_place_dataset_on_device=False, |
| experimental_per_replica_buffer_size=1))) |
| ``` |
| |
Attributes:
experimental_fetch_to_device: Boolean. If True, dataset
elements will be prefetched to accelerator device memory. When False,
dataset elements are prefetched to host device memory. Must be False when
using the TPUEmbedding API. experimental_fetch_to_device can only be used
with experimental_replication_mode=PER_WORKER. Default behavior is the same
as setting it to True.
experimental_replication_mode: Replication mode for the input function.
Currently, InputReplicationMode.PER_REPLICA is only supported with
tf.distribute.MirroredStrategy when used with
experimental_distribute_datasets_from_function.
The default value is InputReplicationMode.PER_WORKER.
experimental_place_dataset_on_device: Boolean. Defaults to False. When True,
the dataset will be placed on the device, otherwise it will remain on the
host. experimental_place_dataset_on_device=True can only be used with
experimental_replication_mode=PER_REPLICA.
experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the
prefetch buffer size in the replica device memory. Users can set it
to 0 to completely disable prefetching behavior, or a number greater than
1 to enable larger buffer sizes. Note that this option is still
valid with `experimental_fetch_to_device=False`.
| """ |
| |
| def __new__(cls, |
| experimental_fetch_to_device=None, |
| experimental_replication_mode=InputReplicationMode.PER_WORKER, |
| experimental_place_dataset_on_device=False, |
| experimental_per_replica_buffer_size=1): |
| if experimental_fetch_to_device is None: |
| experimental_fetch_to_device = True |
| |
| return super(InputOptions, |
| cls).__new__(cls, experimental_fetch_to_device, |
| experimental_replication_mode, |
| experimental_place_dataset_on_device, |
| experimental_per_replica_buffer_size) |
| |
| # ------------------------------------------------------------------------------ |
| # Base classes for all distribution strategies. |
| |
| |
# Base class for v1 Strategy and v2 Strategy classes. For APIs specific to
# v1/v2 Strategy, add to implementing classes of StrategyBase.
| # pylint: disable=line-too-long |
| class StrategyBase(object): |
| """A state & compute distribution policy on a list of devices. |
| |
| See [the guide](https://www.tensorflow.org/guide/distributed_training) |
| for overview and examples. See `tf.distribute.StrategyExtended` and |
| [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute) |
| for a glossary of concepts mentioned on this page such as "per-replica", |
| _replica_, and _reduce_. |
| |
| In short: |
| |
| * To use it with Keras `compile`/`fit`, |
| [please |
| read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). |
* You may pass a descendant of `tf.distribute.Strategy` to
`tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
should distribute its computation. See
[guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
* Otherwise, use `tf.distribute.Strategy.scope` to specify that a
strategy should be used when building and executing your model.
(This puts you in the "cross-replica context" for this strategy, which
means the strategy is put in control of things like variable placement.)
| * If you are writing a custom training loop, you will need to call a few more |
| methods, |
| [see the |
| guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): |
| |
| * Start by creating a `tf.data.Dataset` normally. |
| * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert |
| a `tf.data.Dataset` to something that produces "per-replica" values. |
| If you want to manually specify how the dataset should be partitioned |
| across replicas, use |
| `tf.distribute.Strategy.distribute_datasets_from_function` |
| instead. |
| * Use `tf.distribute.Strategy.run` to run a function |
| once per replica, taking values that may be "per-replica" (e.g. |
| from a `tf.distribute.DistributedDataset` object) and returning |
| "per-replica" values. |
| This function is executed in "replica context", which means each |
| operation is performed separately on each replica. |
| * Finally use a method (such as `tf.distribute.Strategy.reduce`) to |
| convert the resulting "per-replica" values into ordinary `Tensor`s. |
| |
| A custom training loop can be as simple as: |
| |
| ``` |
| with my_strategy.scope(): |
| @tf.function |
| def distribute_train_epoch(dataset): |
| def replica_fn(inputs):
| # Process the per-replica inputs and return a per-replica result.
| return tf.reduce_sum(inputs)
| |
| total_result = 0 |
| for x in dataset: |
| per_replica_result = my_strategy.run(replica_fn, args=(x,)) |
| total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, |
| per_replica_result, axis=None) |
| return total_result |
| |
| dist_dataset = my_strategy.experimental_distribute_dataset(dataset) |
| for _ in range(EPOCHS): |
| train_result = distribute_train_epoch(dist_dataset) |
| ``` |
| |
| This takes an ordinary `dataset` and `replica_fn` and runs them,
| distributed using a particular `tf.distribute.Strategy` named
| `my_strategy` above. Any variables created in `replica_fn` are created
| using `my_strategy`'s policy, and library functions called by |
| `replica_fn` can use the `get_replica_context()` API to implement |
| distributed-specific behavior. |
| |
| You can use the `reduce` API to aggregate results across replicas and use |
| this as a return value from one iteration over a |
| `tf.distribute.DistributedDataset`. Or |
| you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to |
| accumulate metrics across steps in a given epoch. |
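| 
| A sketch of the metrics approach (the metric name and update values below
| are illustrative):
| 
| ```python
| import tensorflow as tf
| 
| train_loss = tf.keras.metrics.Mean(name="train_loss")
| # Each step updates the metric; its variables accumulate the running
| # total and count across steps in the epoch.
| train_loss.update_state(1.0)
| train_loss.update_state(3.0)
| epoch_loss = train_loss.result()  # mean of the updates: 2.0
| ```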
| |
| See the |
| [custom training loop |
| tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) |
| for a more detailed example. |
| |
| Note: `tf.distribute.Strategy` does not currently support TensorFlow's
| partitioned variables (where a single variable is split across multiple
| devices).
| """ |
| # pylint: enable=line-too-long |
| |
| # TODO(josh11b): Partitioned computations, state; sharding |
| # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling |
| |
| def __init__(self, extended): |
| self._extended = extended |
| |
| # Flag that is used to indicate whether distribution strategy is used with |
| # Estimator. This is required for backward compatibility of loss scaling |
| # when using v1 optimizer with estimator. |
| self._scale_loss_for_estimator = False |
| |
| if not hasattr(extended, "_retrace_functions_for_each_device"): |
| # pylint: disable=protected-access |
| # `extended._retrace_functions_for_each_device` dictates |
| # whether the same function will be retraced when it is called on |
| # different devices. |
| try: |
| extended._retrace_functions_for_each_device = ( |
| len(extended.worker_devices) > 1) |
| distribution_strategy_replica_gauge.get_cell("num_replicas").set( |
| self.num_replicas_in_sync) |
| except: # pylint: disable=bare-except |
| # Default for the case where extended.worker_devices can't return |
| # a sensible value. |
| extended._retrace_functions_for_each_device = True |
| |
| # Below are the dicts of axis(int) -> `tf.function`. |
| self._mean_reduce_helper_fns = {} |
| self._reduce_sum_fns = {} |
| |
| # Whether this strategy is designed to work with `ClusterCoordinator`. |
| self._should_use_with_coordinator = False |
| |
| @property |
| def extended(self): |
| """`tf.distribute.StrategyExtended` with additional methods.""" |
| return self._extended |
| |
| @tf_contextlib.contextmanager |
| def _scale_loss_for_estimator_enabled(self): |
| """Scope which sets a flag used for scaling losses in optimizer. |
| |
| Yields: |
| `_scale_loss_for_estimator_enabled` is a context manager with a |
| side effect, but doesn't return a value. |
| """ |
| self._scale_loss_for_estimator = True |
| try: |
| yield |
| finally: |
| self._scale_loss_for_estimator = False |
| |
| # pylint: disable=line-too-long |
| def scope(self): |
| """Context manager to make the strategy current and distribute variables. |
| |
| This method returns a context manager, and is used as follows: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> # Variable created inside scope: |
| >>> with strategy.scope(): |
| ... mirrored_variable = tf.Variable(1.) |
| >>> mirrored_variable |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> |
| } |
| >>> # Variable created outside scope: |
| >>> regular_variable = tf.Variable(1.) |
| >>> regular_variable |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> |
| |
| _What happens when Strategy.scope is entered?_ |
| |
| * `strategy` is installed in the global context as the "current" strategy. |
| Inside this scope, `tf.distribute.get_strategy()` will now return this |
| strategy. Outside this scope, it returns the default no-op strategy. |
| * Entering the scope also enters the "cross-replica context". See |
| `tf.distribute.StrategyExtended` for an explanation on cross-replica and |
| replica contexts. |
| * Variable creation inside `scope` is intercepted by the strategy. Each |
| strategy defines how it wants to affect the variable creation. Sync |
| strategies like `MirroredStrategy`, `TPUStrategy` and |
| `MultiWorkerMirroredStrategy` create variables replicated on each replica,
| whereas `ParameterServerStrategy` creates variables on the parameter |
| servers. This is done using a custom `tf.variable_creator_scope`. |
| * In some strategies, a default device scope may also be entered: in |
| `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
| entered on each worker. |
| |
| Note: Entering a scope does not automatically distribute a computation, except
| in the case of high-level training frameworks like Keras `model.fit`. If
| you're not using `model.fit`, you need to use the `strategy.run` API to
| explicitly distribute that computation.
| See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
| |
| |
| _What should be in scope and what should be outside?_ |
| |
| There are a number of requirements on what needs to happen inside the scope. |
| However, in places where we have information about which strategy is in use, |
| we often enter the scope for the user, so they don't have to do it |
| explicitly (i.e. calling those either inside or outside the scope is OK). |
| |
| * Anything that creates variables that should be distributed variables |
| must be called in a `strategy.scope`. This can be accomplished either by |
| directly calling the variable creating function within the scope context, |
| or by relying on another API like `strategy.run` or `keras.Model.fit` to |
| automatically enter it for you. Any variable that is created outside scope |
| will not be distributed and may have performance implications. Some common |
| objects that create variables in TF are Models, Optimizers, Metrics. Such |
| objects should always be initialized in the scope, and any functions |
| that may lazily create variables (e.g., `Model.__call__()`, tracing a |
| `tf.function`, etc.) should similarly be called within scope. Another |
| source of variable creation can be a checkpoint restore, since variables
| may be created lazily. Note that any variable created inside a strategy
| captures the strategy information. So reading and writing to these |
| variables outside the `strategy.scope` can also work seamlessly, without |
| the user having to enter the scope. |
| * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which |
| require to be in a strategy's scope, enter the scope automatically, which |
| means when using those APIs you don't need to explicitly enter the scope |
| yourself. |
| * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model |
| object captures the scope information. When high level training framework |
| methods such as `model.compile`, `model.fit`, etc. are then called, the |
| captured scope will be automatically entered, and the associated strategy |
| will be used to distribute the training etc. See a detailed example in |
| [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras). |
| WARNING: Simply calling `model(..)` does not automatically enter the |
| captured scope -- only high level training framework APIs support this |
| behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict` |
| and `model.save` can all be called inside or outside the scope. |
| * The following can be either inside or outside the scope: |
| * Creating the input datasets |
| * Defining `tf.function`s that represent your training step |
| * Saving APIs such as `tf.saved_model.save`. Loading creates variables, |
| so that should go inside the scope if you want to train the model in a |
| distributed way. |
| * Checkpoint saving. As mentioned above - `checkpoint.restore` may |
| sometimes need to be inside scope if it creates variables. |
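| 
| As a minimal sketch of this split (the default no-op strategy and the layer
| shape are stand-ins for your own setup):
| 
| ```python
| import tensorflow as tf
| 
| strategy = tf.distribute.get_strategy()  # stand-in for a real strategy
| with strategy.scope():
|   # Variable-creating objects (model, optimizer, metrics) go inside.
|   model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
|   optimizer = tf.keras.optimizers.SGD()
| # Dataset creation may happen outside the scope.
| dataset = tf.data.Dataset.from_tensors((tf.ones((1, 4)), tf.ones((1, 1))))
| ```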
| |
| Returns: |
| A context manager. |
| """ |
| return self._extended._scope(self) # pylint: disable=protected-access |
| # pylint: enable=line-too-long |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` |
| def colocate_vars_with(self, colocate_with_variable): |
| """DEPRECATED: use extended.colocate_vars_with() instead.""" |
| return self._extended.colocate_vars_with(colocate_with_variable) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def make_dataset_iterator(self, dataset): |
| """DEPRECATED TF 1.x ONLY.""" |
| return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def make_input_fn_iterator(self, |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| """DEPRECATED TF 1.x ONLY.""" |
| if replication_mode != InputReplicationMode.PER_WORKER: |
| raise ValueError( |
| "Input replication mode not supported: %r" % replication_mode) |
| with self.scope(): |
| return self.extended._make_input_fn_iterator( # pylint: disable=protected-access |
| input_fn, replication_mode=replication_mode) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def experimental_run(self, fn, input_iterator=None): |
| """DEPRECATED TF 1.x ONLY.""" |
| with self.scope(): |
| args = (input_iterator.get_next(),) if input_iterator is not None else () |
| return self.run(fn, args=args) |
| |
| def experimental_distribute_dataset(self, dataset, options=None): |
| # pylint: disable=line-too-long |
| """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`. |
| |
| The returned `tf.distribute.DistributedDataset` can be iterated over
| similarly to regular datasets.
| NOTE: The user cannot add any more transformations to a |
| `tf.distribute.DistributedDataset`. You can only create an iterator or |
| examine the `tf.TypeSpec` of the data generated by it. See API docs of |
| `tf.distribute.DistributedDataset` to learn more. |
| |
| The following is an example: |
| |
| >>> global_batch_size = 2 |
| >>> # Passing the devices is optional. |
| ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) |
| >>> # Create a dataset |
| ... dataset = tf.data.Dataset.range(4).batch(global_batch_size) |
| >>> # Distribute that dataset |
| ... dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| >>> @tf.function |
| ... def replica_fn(input): |
| ... return input*2 |
| >>> result = [] |
| >>> # Iterate over the `tf.distribute.DistributedDataset` |
| ... for x in dist_dataset: |
| ... # process dataset elements |
| ... result.append(strategy.run(replica_fn, args=(x,))) |
| >>> print(result) |
| [PerReplica:{ |
| 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>, |
| 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])> |
| }, PerReplica:{ |
| 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>, |
| 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])> |
| }] |
| |
| |
| Three key actions happening under the hood of this method are batching, |
| sharding, and prefetching. |
| |
| In the code snippet above, `dataset` is batched by `global_batch_size`, and |
| calling `experimental_distribute_dataset` on it rebatches `dataset` to a |
| new batch size that is equal to the global batch size divided by the number |
| of replicas in sync. We iterate through it with a standard Python for loop.
| `x` is a `tf.distribute.DistributedValues` containing data for all replicas, |
| and each replica gets data of the new batch size. |
| `tf.distribute.Strategy.run` will take care of feeding the right per-replica |
| data in `x` to the right `replica_fn` executed on each replica. |
| |
| Sharding covers autosharding across multiple workers and within each
| worker. First, in multi-worker distributed training (i.e. when you use
| `tf.distribute.experimental.MultiWorkerMirroredStrategy` |
| or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of |
| workers means that each worker is assigned a subset of the entire dataset |
| (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to |
| ensure that at each step, a global batch size of non-overlapping dataset |
| elements will be processed by each worker. Autosharding has a couple of |
| different options that can be specified using |
| `tf.data.experimental.DistributeOptions`. Then, sharding within each worker |
| means the method will split the data among all the worker devices (if more
| than one is present). This will happen regardless of multi-worker
| autosharding.
| |
| Note: for autosharding across multiple workers, the default mode is |
| `tf.data.experimental.AutoShardPolicy.AUTO`. This mode |
| will attempt to shard the input dataset by files if the dataset is |
| being created out of reader datasets (e.g. `tf.data.TFRecordDataset`, |
| `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data, |
| where each of the workers will read the entire dataset and only process the |
| shard assigned to it. However, if you have fewer than one input file per
| worker, we suggest that you disable dataset autosharding across workers by |
| setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be |
| `tf.data.experimental.AutoShardPolicy.OFF`. |
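| 
| Turning autosharding off can be sketched with `tf.data.Options` (the dataset
| below is a stand-in for your input pipeline):
| 
| ```python
| import tensorflow as tf
| 
| options = tf.data.Options()
| # Each worker will now see the full dataset instead of a shard of it.
| options.experimental_distribute.auto_shard_policy = (
|     tf.data.experimental.AutoShardPolicy.OFF)
| dataset = tf.data.Dataset.range(8).batch(4).with_options(options)
| ```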
| |
| By default, this method adds a prefetch transformation at the end of the
| user provided `tf.data.Dataset` instance. The `buffer_size` argument of the
| prefetch transformation is equal to the number of replicas in sync.
| |
| If the above batch splitting and dataset sharding logic is undesirable, |
| please use |
| `tf.distribute.Strategy.distribute_datasets_from_function` |
| instead, which does not do any automatic batching or sharding for you. |
| |
| Note: If you are using TPUStrategy, the order in which the data is processed |
| by the workers when using |
| `tf.distribute.Strategy.experimental_distribute_dataset` or |
| `tf.distribute.Strategy.distribute_datasets_from_function` is |
| not guaranteed. This is typically required if you are using |
| `tf.distribute` to scale prediction. You can however insert an index for |
| each element in the batch and order outputs accordingly. Refer to [this |
| snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) |
| for an example of how to order outputs. |
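| 
| The indexing idea can be sketched without a TPU by enumerating batches and
| sorting the results afterwards (doubling stands in for a prediction step):
| 
| ```python
| import tensorflow as tf
| 
| dataset = tf.data.Dataset.range(4).batch(2)
| results = {}
| # `enumerate` attaches an index to each batch so out-of-order results
| # can be re-ordered once all of them are in.
| for index, batch in dataset.enumerate():
|   results[int(index)] = (batch * 2).numpy().tolist()
| ordered = [results[i] for i in sorted(results)]  # [[0, 2], [4, 6]]
| ```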
| |
| Note: Stateful dataset transformations are currently not supported with |
| `tf.distribute.experimental_distribute_dataset` or |
| `tf.distribute.distribute_datasets_from_function`. Any stateful |
| ops that the dataset may have are currently ignored. For example, if your |
| dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, |
| then you have a dataset graph that depends on state (i.e. the random seed) on
| the local machine where the python process is being executed. |
| |
| For a tutorial on more usage and properties of this method, refer to the |
| [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset). |
| If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). |
| |
| Args: |
| dataset: `tf.data.Dataset` that will be sharded across all replicas using |
| the rules stated above. |
| options: `tf.distribute.InputOptions` used to control options on how this |
| dataset is distributed. |
| |
| Returns: |
| A `tf.distribute.DistributedDataset`. |
| """ |
| # pylint: enable=line-too-long |
| return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access |
| |
| def distribute_datasets_from_function(self, dataset_fn, options=None): |
| # pylint: disable=line-too-long |
| """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. |
| |
| The argument `dataset_fn` that users pass in is an input function that has a |
| `tf.distribute.InputContext` argument and returns a `tf.data.Dataset` |
| instance. It is expected that the returned dataset from `dataset_fn` is |
| already batched by per-replica batch size (i.e. global batch size divided by |
| the number of replicas in sync) and sharded. |
| `tf.distribute.Strategy.distribute_datasets_from_function` does |
| not batch or shard the `tf.data.Dataset` instance |
| returned from the input function. `dataset_fn` will be called on the CPU |
| device of each of the workers and each generates a dataset where every |
| replica on that worker will dequeue one batch of inputs (i.e. if a worker |
| has two replicas, two batches will be dequeued from the `Dataset` every |
| step). |
| |
| This method can be used for several purposes. First, it allows you to |
| specify your own batching and sharding logic. (In contrast, |
| `tf.distribute.experimental_distribute_dataset` does batching and sharding |
| for you.) For example, where |
| `experimental_distribute_dataset` is unable to shard the input files, this |
| method might be used to manually shard the dataset (avoiding the slow |
| fallback behavior in `experimental_distribute_dataset`). In cases where the |
| dataset is infinite, this sharding can be done by creating dataset replicas |
| that differ only in their random seed. |
| |
| The `dataset_fn` should take a `tf.distribute.InputContext` instance where
| information about batching and input replication can be accessed. |
| |
| You can use `element_spec` property of the |
| `tf.distribute.DistributedDataset` returned by this API to query the |
| `tf.TypeSpec` of the elements returned by the iterator. This can be used to |
| set the `input_signature` property of a `tf.function`. Follow |
| `tf.distribute.DistributedDataset.element_spec` to see an example. |
| |
| IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a |
| per-replica batch size, unlike `experimental_distribute_dataset`, which uses |
| the global batch size. This may be computed using |
| `input_context.get_per_replica_batch_size`. |
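| 
| A sketch of such a `dataset_fn`, driven here by a hand-built
| `tf.distribute.InputContext` for a single-pipeline, single-replica setup:
| 
| ```python
| import tensorflow as tf
| 
| global_batch_size = 8
| 
| def dataset_fn(input_context):
|   # Per-replica batch size: global batch divided by replicas in sync.
|   batch_size = input_context.get_per_replica_batch_size(global_batch_size)
|   d = tf.data.Dataset.range(16).batch(batch_size)
|   # Shard by this worker's position among all input pipelines.
|   return d.shard(input_context.num_input_pipelines,
|                  input_context.input_pipeline_id)
| 
| ctx = tf.distribute.InputContext(
|     num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=1)
| per_worker_dataset = dataset_fn(ctx)
| ```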
| |
| Note: If you are using TPUStrategy, the order in which the data is processed |
| by the workers when using |
| `tf.distribute.Strategy.experimental_distribute_dataset` or |
| `tf.distribute.Strategy.distribute_datasets_from_function` is |
| not guaranteed. This is typically required if you are using |
| `tf.distribute` to scale prediction. You can however insert an index for |
| each element in the batch and order outputs accordingly. Refer to [this |
| snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) |
| for an example of how to order outputs. |
| |
| Note: Stateful dataset transformations are currently not supported with |
| `tf.distribute.experimental_distribute_dataset` or |
| `tf.distribute.distribute_datasets_from_function`. Any stateful |
| ops that the dataset may have are currently ignored. For example, if your |
| dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, |
| then you have a dataset graph that depends on state (i.e. the random seed) on
| the local machine where the python process is being executed. |
| |
| For a tutorial on more usage and properties of this method, refer to the |
| [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
| If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). |
| |
| Args: |
| dataset_fn: A function taking a `tf.distribute.InputContext` instance and |
| returning a `tf.data.Dataset`. |
| options: `tf.distribute.InputOptions` used to control options on how this |
| dataset is distributed. |
| |
| Returns: |
| A `tf.distribute.DistributedDataset`. |
| """ |
| # pylint: enable=line-too-long |
| return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access |
| dataset_fn, options) |
| |
| # TODO(b/162776748): Remove deprecated symbol. |
| @doc_controls.do_not_doc_inheritable |
| @deprecation.deprecated(None, "rename to distribute_datasets_from_function") |
| def experimental_distribute_datasets_from_function(self, |
| dataset_fn, |
| options=None): |
| return self.distribute_datasets_from_function(dataset_fn, options) |
| |
| def run(self, fn, args=(), kwargs=None, options=None): |
| """Invokes `fn` on each replica, with the given arguments. |
| |
| This method is the primary way to distribute your computation with a |
| tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs` |
| have `tf.distribute.DistributedValues`, such as those produced by a |
| `tf.distribute.DistributedDataset` from |
| `tf.distribute.Strategy.experimental_distribute_dataset` or |
| `tf.distribute.Strategy.distribute_datasets_from_function`, |
| when `fn` is executed on a particular replica, it will be executed with the |
| component of `tf.distribute.DistributedValues` that corresponds to that
| replica. |
| |
| `fn` is invoked under a replica context. `fn` may call |
| `tf.distribute.get_replica_context()` to access members such as |
| `all_reduce`. Please see the module-level docstring of tf.distribute for the |
| concept of replica context. |
| |
| All arguments in `args` or `kwargs` can be a nested structure of tensors, |
| e.g. a list of tensors, in which case `args` and `kwargs` will be passed to |
| the `fn` invoked on each replica. Or `args` or `kwargs` can be |
| `tf.distribute.DistributedValues` containing tensors or composite tensors, |
| i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call |
| will get the component of a `tf.distribute.DistributedValues` corresponding |
| to its replica. Note that arbitrary Python values that are not of the types |
| above are not supported. |
| |
| IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and |
| whether eager execution is enabled, `fn` may be called one or more times. If |
| `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is |
| called inside a `tf.function` (eager execution is disabled inside a |
| `tf.function` by default), `fn` is called once per replica to generate a |
| TensorFlow graph, which will then be reused for execution with new inputs.
| Otherwise, if eager execution is enabled, `fn` will be called once per |
| replica every step just like regular python code. |
| |
| Example usage: |
| |
| 1. Constant tensor input. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> tensor_input = tf.constant(3.0) |
| >>> @tf.function |
| ... def replica_fn(input): |
| ... return input*2.0 |
| >>> result = strategy.run(replica_fn, args=(tensor_input,)) |
| >>> result |
| PerReplica:{ |
| 0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>, |
| 1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0> |
| } |
| |
| 2. DistributedValues input. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> @tf.function |
| ... def run(): |
| ... def value_fn(value_context): |
| ... return value_context.num_replicas_in_sync |
| ... distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| ... def replica_fn2(input): |
| ... return input*2 |
| ... return strategy.run(replica_fn2, args=(distributed_values,)) |
| >>> result = run() |
| >>> result |
| <tf.Tensor: shape=(), dtype=int32, numpy=4> |
| |
| 3. Use `tf.distribute.ReplicaContext` to allreduce values. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"]) |
| >>> @tf.function |
| ... def run(): |
| ... def value_fn(value_context): |
| ... return tf.constant(value_context.replica_id_in_sync_group) |
| ... distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| ... def replica_fn(input): |
| ... return tf.distribute.get_replica_context().all_reduce("sum", input) |
| ... return strategy.run(replica_fn, args=(distributed_values,)) |
| >>> result = run() |
| >>> result |
| PerReplica:{ |
| 0: <tf.Tensor: shape=(), dtype=int32, numpy=1>, |
| 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> |
| } |
| |
| Args: |
| fn: The function to run on each replica. |
| args: Optional positional arguments to `fn`. Its elements can be tensors,
| nested structures of tensors, or `tf.distribute.DistributedValues`.
| kwargs: Optional keyword arguments to `fn`. Its elements can be tensors,
| nested structures of tensors, or `tf.distribute.DistributedValues`.
| options: An optional instance of `tf.distribute.RunOptions` specifying |
| the options to run `fn`. |
| |
| Returns: |
| Merged return value of `fn` across replicas. The structure of the return
| value is the same as the return value from `fn`. Each element in the
| structure can be a `tf.distribute.DistributedValues` object or a `Tensor`
| (for example, when running on a single replica).
| """ |
| del options |
| |
| if not isinstance(args, (list, tuple)): |
| raise ValueError( |
| "positional args must be a list or tuple, got {}".format(type(args))) |
| |
| with self.scope(): |
| # tf.distribute supports Eager functions, so AutoGraph should not be |
| # applied when the caller is also in Eager mode. |
| fn = autograph.tf_convert( |
| fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) |
| |
| def reduce(self, reduce_op, value, axis): |
| """Reduce `value` across replicas and return result on current device. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def step_fn(): |
| ... i = tf.distribute.get_replica_context().replica_id_in_sync_group |
| ... return tf.identity(i) |
| >>> |
| >>> per_replica_result = strategy.run(step_fn) |
| >>> total = strategy.reduce("SUM", per_replica_result, axis=None) |
| >>> total |
| <tf.Tensor: shape=(), dtype=int32, numpy=1> |
| |
| To see how this would look with multiple replicas, consider the same |
| example with MirroredStrategy with 2 GPUs: |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) |
| def step_fn(): |
| i = tf.distribute.get_replica_context().replica_id_in_sync_group |
| return tf.identity(i) |
| |
| per_replica_result = strategy.run(step_fn) |
| # Check devices on which per replica result is: |
| strategy.experimental_local_results(per_replica_result)[0].device |
| # /job:localhost/replica:0/task:0/device:GPU:0 |
| strategy.experimental_local_results(per_replica_result)[1].device |
| # /job:localhost/replica:0/task:0/device:GPU:1 |
| |
| total = strategy.reduce("SUM", per_replica_result, axis=None) |
| # Check device on which reduced result is: |
| total.device |
| # /job:localhost/replica:0/task:0/device:CPU:0 |
| |
| ``` |
| |
| This API is typically used for aggregating the results returned from |
| different replicas, for reporting etc. For example, loss computed from |
| different replicas can be averaged using this API before printing. |
| |
| Note: The result is copied to the "current" device - which would typically |
| be the CPU of the worker on which the program is running. For `TPUStrategy`, |
| it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`,
| this is the CPU of each worker.
| |
| There are a number of different tf.distribute APIs for reducing values |
| across replicas: |
| * `tf.distribute.ReplicaContext.all_reduce`: This differs from |
| `Strategy.reduce` in that it is for replica context and does |
| not copy the results to the host device. `all_reduce` should be typically |
| used for reductions inside the training step such as gradients. |
| * `tf.distribute.StrategyExtended.reduce_to` and |
| `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more |
| advanced versions of `Strategy.reduce` as they allow customizing the |
| destination of the result. They are also called in cross replica context. |
| |
| _What should axis be?_ |
| |
| Given a per-replica value returned by `run`, say a |
| per-example loss, the batch will be divided across all the replicas. This |
| function allows you to aggregate across replicas and optionally also across |
| batch elements by specifying the axis parameter accordingly. |
| |
| For example, if you have a global batch size of 8 and 2 |
| replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and |
| `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will |
| aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. |
| This is useful when each replica is computing a scalar or some other value |
| that doesn't have a "batch" dimension (like a gradient or loss). |
| ``` |
| strategy.reduce("sum", per_replica_result, axis=None) |
| ``` |
| |
| Sometimes, you will want to aggregate across both the global batch _and_ |
| all replicas. You can get this behavior by specifying the batch |
| dimension as the `axis`, typically `axis=0`. In this case it would return a |
| scalar `0+1+2+3+4+5+6+7`. |
| ``` |
| strategy.reduce("sum", per_replica_result, axis=0) |
| ``` |
| |
| If there is a last partial batch, you will need to specify an axis so |
| that the resulting shape is consistent across replicas. So if the last |
| batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you |
| would get a shape mismatch unless you specify `axis=0`. If you specify
| `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
| denominator of 6. Contrast this with computing `reduce_mean` on each
| replica to get a scalar and then using this function to average those
| means, which would weigh some values `1/8` and others `1/4`.
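| 
| With the default (single-replica) strategy this can be sketched as follows;
| `reduce` accepts a regular tensor here, and `axis=0` folds the batch
| dimension into the mean:
| 
| ```python
| import tensorflow as tf
| 
| strategy = tf.distribute.get_strategy()  # default strategy, one replica
| values = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
| # axis=0 reduces over the batch dimension as well as across replicas,
| # so MEAN uses the true denominator of 6.
| mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, values, axis=0)
| ```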
| |
| Args: |
| reduce_op: a `tf.distribute.ReduceOp` value specifying how values should |
| be combined. Allows using string representation of the enum such as |
| "SUM", "MEAN". |
| value: a `tf.distribute.DistributedValues` instance, e.g. returned by |
| `Strategy.run`, to be combined into a single tensor. It can also be a |
| regular tensor when used with `OneDeviceStrategy` or default strategy. |
| axis: specifies the dimension to reduce along within each |
| replica's tensor. Should typically be set to the batch dimension, or |
| `None` to only reduce across replicas (e.g. if the tensor has no batch |
| dimension). |
| |
| Returns: |
| A `Tensor`. |
| """ |
| # TODO(josh11b): support `value` being a nest. |
| _require_cross_replica_or_default_context_extended(self._extended) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| if axis is None: |
| return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access |
| if reduce_op == reduce_util.ReduceOp.SUM: |
| |
| def reduce_sum(v): |
| return math_ops.reduce_sum(v, axis=axis) |
| |
| if eager_context.executing_eagerly(): |
| # As some strategies (e.g. TPUStrategy) don't support pure eager |
| # execution, wrap `reduce_sum_fn` in a `tf.function` so it can be |
| # run from eager mode. Cache the `tf.function` by `axis` to avoid |
| # retracing the same function. |
| if axis not in self._reduce_sum_fns: |
| |
| def reduce_sum_fn(v): |
| return self.run(reduce_sum, args=(v,)) |
| |
| self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn) |
| value = self._reduce_sum_fns[axis](value) |
| else: |
| value = self.run(reduce_sum, args=(value,)) |
| |
| return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access |
| if reduce_op != reduce_util.ReduceOp.MEAN: |
| raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " |
| "not: %r" % reduce_op) |
| # TODO(josh11b): Support list/tuple and tensor axis values. |
| if not isinstance(axis, six.integer_types): |
| raise TypeError("Expected `axis` to be an integer, not: %r" % axis) |
| |
| def mean_reduce_helper(v, axis=axis): |
| """Computes the numerator and denominator on each replica.""" |
| numer = math_ops.reduce_sum(v, axis=axis) |
| if v.shape.rank is not None: |
| # Note(joshl): We support axis < 0 to be consistent with the |
| # tf.math.reduce_* operations. |
| if axis < 0: |
| if axis + v.shape.rank < 0: |
| raise ValueError( |
| "`axis` = %r out of range for `value` with rank %d" % |
| (axis, v.shape.rank)) |
| axis += v.shape.rank |
| elif axis >= v.shape.rank: |
| raise ValueError( |
| "`axis` = %r out of range for `value` with rank %d" % |
| (axis, v.shape.rank)) |
| # TF v2 returns `None` for unknown dimensions and an integer for |
| # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) |
| # or tensor_shape.Dimension(integer). `dimension_value` hides this |
| # difference, always returning `None` or an integer. |
| dim = tensor_shape.dimension_value(v.shape[axis]) |
| if dim is not None: |
| # By returning a python value in the static shape case, we may get a |
| # fast path for reducing the denominator. |
| # TODO(b/151871486): Remove array_ops.identity after we fallback to |
| # simple reduction if inputs are all on CPU. |
| return numer, array_ops.identity( |
| constant_op.constant(dim, dtype=dtypes.int64)) |
| elif axis < 0: |
| axis = axis + array_ops.rank(v) |
| # TODO(b/151871486): Remove array_ops.identity after we fallback to simple |
| # reduction if inputs are all on CPU. |
| denom = array_ops.identity( |
| array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) |
| # TODO(josh11b): Should we cast denom to v.dtype here instead of after the |
| # reduce is complete? |
| return numer, denom |
| |
| if eager_context.executing_eagerly(): |
| # As some strategies (e.g. TPUStrategy) don't support pure eager |
| # execution, wrap `mean_reduce_helper` in a `tf.function` so it can |
| # be run from eager mode. Cache the `tf.function` by `axis` to avoid |
| # retracing the same function. |
| if axis not in self._mean_reduce_helper_fns: |
| |
| def mean_reduce_fn(v): |
| return self.run(mean_reduce_helper, args=(v,)) |
| |
| self._mean_reduce_helper_fns[axis] = def_function.function( |
| mean_reduce_fn) |
| numer, denom = self._mean_reduce_helper_fns[axis](value) |
| else: |
| numer, denom = self.run(mean_reduce_helper, args=(value,)) |
| |
| # TODO(josh11b): Should batch reduce here instead of doing two. |
| numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access |
| denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access |
| denom = math_ops.cast(denom, numer.dtype) |
| return math_ops.truediv(numer, denom) |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED |
| def unwrap(self, value): |
| """Returns the list of all local per-replica values contained in `value`. |
| |
| DEPRECATED: Please use `experimental_local_results` instead. |
| |
| Note: This only returns values on the worker initiated by this client. |
| When using a `tf.distribute.Strategy` like |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker |
| will be its own client, and this function will only return values |
| computed on that worker. |
| |
| Args: |
| value: A value returned by `experimental_run()`, |
| `extended.call_for_each_replica()`, or a variable created in `scope`. |
| |
| Returns: |
| A tuple of values contained in `value`. If `value` represents a single |
| value, this returns `(value,)`. |
| """ |
| return self._extended._local_results(value) # pylint: disable=protected-access |
| |
| def experimental_local_results(self, value): |
| """Returns the list of all local per-replica values contained in `value`. |
| |
| Note: This only returns values on the worker initiated by this client. |
| When using a `tf.distribute.Strategy` like |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker |
| will be its own client, and this function will only return values |
| computed on that worker. |
| |
| Args: |
| value: A value returned by `experimental_run()`, `run()`, or a variable |
| created in `scope`. |
| |
| Returns: |
| A tuple of values contained in `value`, where the i-th element |
| corresponds to the i-th replica. If `value` represents a single value, |
| this returns `(value,)`. |
| """ |
| return self._extended._local_results(value) # pylint: disable=protected-access |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only |
| def group(self, value, name=None): |
| """Shortcut for `tf.group(self.experimental_local_results(value))`.""" |
| return self._extended._group(value, name) # pylint: disable=protected-access |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns number of replicas over which gradients are aggregated.""" |
| return self._extended._num_replicas_in_sync # pylint: disable=protected-access |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string |
| def configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| # pylint: disable=g-doc-return-or-yield,g-doc-args |
| """DEPRECATED: use `update_config_proto` instead. |
| |
| Configures the strategy class. |
| |
| DEPRECATED: This method's functionality has been split into the strategy |
| constructor and `update_config_proto`. In the future, we will allow passing |
| cluster and config_proto to the constructor to configure the strategy. And |
| `update_config_proto` can be used to update the config_proto based on the |
| specific strategy. |
| """ |
| return self._extended._configure( # pylint: disable=protected-access |
| session_config, cluster_spec, task_type, task_id) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED |
| def update_config_proto(self, config_proto): |
| """DEPRECATED TF 1.x ONLY.""" |
| return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access |
| |
| def __deepcopy__(self, memo): |
| # First do a regular deepcopy of `self`. |
| cls = self.__class__ |
| result = cls.__new__(cls) |
| memo[id(self)] = result |
| for k, v in self.__dict__.items(): |
| setattr(result, k, copy.deepcopy(v, memo)) |
| # One little fix-up: we want `result._extended` to reference `result` |
| # instead of `self`. |
| result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access |
| return result |
| |
| def __copy__(self): |
| raise RuntimeError("Must only deepcopy DistributionStrategy.") |
| |
| @property |
| def cluster_resolver(self): |
| """Returns the cluster resolver associated with this strategy. |
| |
| In general, when using a multi-worker `tf.distribute` strategy such as |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy` or |
| `tf.distribute.TPUStrategy`, there is a |
| `tf.distribute.cluster_resolver.ClusterResolver` associated with the |
| strategy used, and such an instance is returned by this property. |
| |
| Strategies that intend to have an associated |
| `tf.distribute.cluster_resolver.ClusterResolver` must set the |
| relevant attribute, or override this property; otherwise, `None` is returned |
| by default. Those strategies should also provide information regarding what |
| is returned by this property. |
| |
| Single-worker strategies usually do not have a |
| `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this |
| property will return `None`. |
| |
| The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the |
| user needs to access information such as the cluster spec, task type or task |
| id. For example, |
| |
| ```python |
| |
| os.environ['TF_CONFIG'] = json.dumps({ |
| 'cluster': { |
| 'worker': ["localhost:12345", "localhost:23456"], |
| 'ps': ["localhost:34567"] |
| }, |
| 'task': {'type': 'worker', 'index': 0} |
| }) |
| |
| # This implicitly uses TF_CONFIG for the cluster and current task info. |
| strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() |
| |
| ... |
| |
| if strategy.cluster_resolver.task_type == 'worker': |
|   # Perform something that's only applicable on workers. Since we set this |
|   # as a worker above, this block will run on this particular instance. |
|   ... |
| elif strategy.cluster_resolver.task_type == 'ps': |
|   # Perform something that's only applicable on parameter servers. Since we |
|   # set this as a worker above, this block will not run on this particular |
|   # instance. |
|   ... |
| ``` |
| |
| For more information, please see |
| `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring. |
| |
| Returns: |
| The cluster resolver associated with this strategy. Returns `None` if a |
| cluster resolver is not applicable or available in this strategy. |
| """ |
| if hasattr(self.extended, "_cluster_resolver"): |
| return self.extended._cluster_resolver # pylint: disable=protected-access |
| return None |
| |
| |
| @tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring |
| class Strategy(StrategyBase): |
| |
| __doc__ = StrategyBase.__doc__ |
| |
| def experimental_distribute_values_from_function(self, value_fn): |
| """Generates `tf.distribute.DistributedValues` from `value_fn`. |
| |
| This function is to generate `tf.distribute.DistributedValues` to pass |
| into `run`, `reduce`, or other methods that take |
| distributed values when not using datasets. |
| |
| Args: |
| value_fn: The function to run to generate values. It is called for |
| each replica with `tf.distribute.ValueContext` as the sole argument. It |
| must return a Tensor or a type that can be converted to a Tensor. |
| |
| Returns: |
| A `tf.distribute.DistributedValues` containing a value for each replica. |
| |
| Example usage: |
| |
| 1. Return constant value per replica: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def value_fn(ctx): |
| ... return tf.constant(1.) |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>, |
| <tf.Tensor: shape=(), dtype=float32, numpy=1.0>) |
| |
| 2. Distribute values in array based on replica_id: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> array_value = np.array([3., 2., 1.]) |
| >>> def value_fn(ctx): |
| ... return array_value[ctx.replica_id_in_sync_group] |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (3.0, 2.0) |
| |
| 3. Specify values using num_replicas_in_sync: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def value_fn(ctx): |
| ... return ctx.num_replicas_in_sync |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (2, 2) |
| |
| 4. Place values on devices and distribute: |
| |
| ``` |
| strategy = tf.distribute.TPUStrategy() |
| worker_devices = strategy.extended.worker_devices |
| multiple_values = [] |
| for i in range(strategy.num_replicas_in_sync): |
| with tf.device(worker_devices[i]): |
| multiple_values.append(tf.constant(1.0)) |
| |
| def value_fn(ctx): |
| return multiple_values[ctx.replica_id_in_sync_group] |
| |
| distributed_values = ( |
|     strategy.experimental_distribute_values_from_function(value_fn)) |
| ``` |
| |
| """ |
| return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access |
| value_fn) |
| |
| def gather(self, value, axis): |
| # pylint: disable=line-too-long, protected-access |
| """Gather `value` across replicas along `axis` to the current device. |
| |
| Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like |
| object `value`, this API gathers and concatenates `value` across replicas |
| along the `axis`-th dimension. The result is copied to the "current" device, |
| which would typically be the CPU of the worker on which the program is |
| running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For |
| multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of |
| each worker. |
| |
| This API can only be called in the cross-replica context. For a counterpart |
| in the replica context, see `tf.distribute.ReplicaContext.all_gather`. |
| |
| Note: For all strategies except `tf.distribute.TPUStrategy`, the input |
| `value` on different replicas must have the same rank, and their shapes must |
| be the same in all dimensions except the `axis`-th dimension. In other |
| words, their shapes cannot be different in a dimension `d` where `d` |
| does not equal the `axis` argument. For example, given a |
| `tf.distribute.DistributedValues` with component tensors of shape |
| `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call |
| `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or |
| `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, |
| all tensors must have exactly the same rank and same shape. |
| |
| Note: Given a `tf.distribute.DistributedValues` `value`, its component |
| tensors must have a non-zero rank. If a component has rank zero, apply |
| `tf.expand_dims` to it before gathering. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> # A DistributedValues with component tensor of shape (2, 1) on each replica |
| ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) |
| >>> @tf.function |
| ... def run(): |
| ... return strategy.gather(distributed_values, axis=0) |
| >>> run() |
| <tf.Tensor: shape=(4, 1), dtype=int32, numpy= |
| array([[1], |
| [2], |
| [1], |
| [2]], dtype=int32)> |
| |
| |
| Consider the following example for more combinations: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) |
| >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) |
| >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) |
| >>> @tf.function |
| ... def run(axis): |
| ... return strategy.gather(distributed_values, axis=axis) |
| >>> axis=0 |
| >>> run(axis) |
| <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy= |
| array([[[0, 1, 2], |
| [3, 4, 5]], |
| [[0, 1, 2], |
| [3, 4, 5]], |
| [[0, 1, 2], |
| [3, 4, 5]], |
| [[0, 1, 2], |
| [3, 4, 5]]], dtype=int32)> |
| >>> axis=1 |
| >>> run(axis) |
| <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy= |
| array([[[0, 1, 2], |
| [3, 4, 5], |
| [0, 1, 2], |
| [3, 4, 5], |
| [0, 1, 2], |
| [3, 4, 5], |
| [0, 1, 2], |
| [3, 4, 5]]], dtype=int32)> |
| >>> axis=2 |
| >>> run(axis) |
| <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy= |
| array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], |
| [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)> |
| |
| |
| Args: |
| value: a `tf.distribute.DistributedValues` instance, e.g. returned by |
| `Strategy.run`, to be combined into a single tensor. It can also be a |
| regular tensor when used with `tf.distribute.OneDeviceStrategy` or the |
| default strategy. The tensors that constitute the DistributedValues |
| can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. |
| axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the |
| range [0, rank(value)). |
| |
| Returns: |
| A `Tensor` that's the concatenation of `value` across replicas along |
| `axis` dimension. |
| """ |
| # pylint: enable=line-too-long |
| error_message = ("tf.distribute.Strategy.gather method requires " |
| "cross-replica context, use " |
| "get_replica_context().all_gather() instead") |
| _require_cross_replica_or_default_context_extended(self._extended, |
| error_message) |
| dst = device_util.current( |
| ) or self._extended._default_device or "/device:CPU:0" |
| if isinstance(value, ops.IndexedSlices): |
| raise NotImplementedError("gather does not support IndexedSlices") |
| return self._extended._local_results( |
| self._extended._gather_to(value, dst, axis))[0] |
| |
| |
| # TF v1.x version has additional deprecated APIs |
| @tf_export(v1=["distribute.Strategy"]) |
| class StrategyV1(StrategyBase): |
| """A list of devices with a state & compute distribution policy. |
| |
| See [the guide](https://www.tensorflow.org/guide/distribute_strategy) |
| for overview and examples. |
| |
| Note: Not all `tf.distribute.Strategy` implementations currently support |
| TensorFlow's partitioned variables (where a single variable is split |
| across multiple devices). |
| """ |
| |
| def make_dataset_iterator(self, dataset): |
| """Makes an iterator for input provided via `dataset`. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| Data from the given dataset will be distributed evenly across all the |
| compute replicas. We will assume that the input dataset is batched by the |
| global batch size. With this assumption, we will make a best effort to |
| divide each batch across all the replicas (one or more workers). |
| If this effort fails, an error will be raised, and the user should |
| instead use `make_input_fn_iterator`, which provides more control and |
| does not try to divide a batch across replicas. |
| |
| The user could also use `make_input_fn_iterator` if they want to |
| customize which input is fed to which replica/worker etc. |
| |
| Args: |
| dataset: `tf.data.Dataset` that will be distributed evenly across all |
| replicas. |
| |
| Returns: |
| A `tf.distribute.InputIterator` which returns inputs for each step of |
| the computation. The user should call `initialize` on the returned |
| iterator. |
| """ |
| return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access |
| |
| def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| """Returns an iterator split across replicas created from an input function. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| The `input_fn` should take a `tf.distribute.InputContext` object, where |
| information about batching and input sharding can be accessed: |
| |
| ``` |
| def input_fn(input_context): |
| batch_size = input_context.get_per_replica_batch_size(global_batch_size) |
| d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) |
| return d.shard(input_context.num_input_pipelines, |
| input_context.input_pipeline_id) |
| with strategy.scope(): |
| iterator = strategy.make_input_fn_iterator(input_fn) |
| replica_results = strategy.experimental_run(replica_fn, iterator) |
| ``` |
| |
| The `tf.data.Dataset` returned by `input_fn` should have a per-replica |
| batch size, which may be computed using |
| `input_context.get_per_replica_batch_size`. |
| |
| Args: |
| input_fn: A function taking a `tf.distribute.InputContext` object and |
| returning a `tf.data.Dataset`. |
| replication_mode: an enum value of `tf.distribute.InputReplicationMode`. |
| Only `PER_WORKER` is supported currently, which means there will be |
| a single call to `input_fn` per worker. Replicas will dequeue from the |
| local `tf.data.Dataset` on their worker. |
| |
| Returns: |
| An iterator object that should first be `.initialize()`-ed. It may then |
| either be passed to `strategy.experimental_run()`, or you can call |
| `iterator.get_next()` to get the next value to pass to |
| `strategy.extended.call_for_each_replica()`. |
| """ |
| return super(StrategyV1, self).make_input_fn_iterator( |
| input_fn, replication_mode) |
| |
| def experimental_make_numpy_dataset(self, numpy_input, session=None): |
| """Makes a tf.data.Dataset for input provided via a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
| |
| Note that you will likely need to use |
| `tf.distribute.Strategy.experimental_distribute_dataset` with the |
| returned dataset to further distribute it with the strategy. |
| |
| Example: |
| ``` |
| numpy_input = np.ones([10], dtype=np.float32) |
| dataset = strategy.experimental_make_numpy_dataset(numpy_input) |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| ``` |
| |
| Args: |
| numpy_input: A nest of NumPy input arrays that will be converted into a |
| dataset. Note that lists of NumPy arrays are stacked, as that is normal |
| `tf.data.Dataset` behavior. |
| session: (TensorFlow v1.x graph execution only) A session used for |
| initialization. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| return self.extended.experimental_make_numpy_dataset( |
| numpy_input, session=session) |
| |
| def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation |
| """Runs ops in `fn` on each replica, with inputs from `input_iterator`. |
| |
| DEPRECATED: This method is not available in TF 2.x. Please switch |
| to using `run` instead. |
| |
| When eager execution is enabled, executes ops specified by `fn` on each |
| replica. Otherwise, builds a graph to execute the ops on each replica. |
| |
| Each replica will take a single, different input from the inputs provided by |
| one `get_next` call on the input iterator. |
| |
| `fn` may call `tf.distribute.get_replica_context()` to access members such |
| as `replica_id_in_sync_group`. |
| |
| IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being |
| used, and whether eager execution is enabled, `fn` may be called one or more |
| times (once for each replica). |
| |
| Args: |
| fn: The function to run. The inputs to the function must match the outputs |
| of `input_iterator.get_next()`. The output must be a `tf.nest` of |
| `Tensor`s. |
| input_iterator: (Optional) input iterator from which the inputs are taken. |
| |
| Returns: |
| Merged return value of `fn` across replicas. The structure of the return |
| value is the same as the return value from `fn`. Each element in the |
| structure can either be `PerReplica` (if the values are unsynchronized), |
| `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a |
| single replica). |
| """ |
| return super(StrategyV1, self).experimental_run( |
| fn, input_iterator) |
| |
| def reduce(self, reduce_op, value, axis=None): |
| return super(StrategyV1, self).reduce(reduce_op, value, axis) |
| |
| reduce.__doc__ = StrategyBase.reduce.__doc__ |
| |
| def update_config_proto(self, config_proto): |
| """Returns a copy of `config_proto` modified for use with this strategy. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| The updated config has something needed to run a strategy, e.g. |
| configuration to run collective ops, or device filters to improve |
| distributed training performance. |
| |
| Args: |
| config_proto: a `tf.ConfigProto` object. |
| |
| Returns: |
| The updated copy of the `config_proto`. |
| """ |
| return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access |
| |
| |
| # NOTE(josh11b): For any strategy that needs to support tf.compat.v1, |
| # instead descend from StrategyExtendedV1. |
| @tf_export("distribute.StrategyExtended", v1=[]) |
| class StrategyExtendedV2(object): |
| """Additional APIs for algorithms that need to be distribution-aware. |
| |
| Note: For most usage of `tf.distribute.Strategy`, there should be no need to |
| call these methods, since TensorFlow libraries (such as optimizers) already |
| call these methods when needed on your behalf. |
| |
| |
| Some common use cases of functions on this page: |
| |
| * _Locality_ |
| |
| `tf.distribute.DistributedValues` can have the same _locality_ as a |
| _distributed variable_, which leads to a mirrored value residing on the same |
| devices as the variable (as opposed to the compute devices). Such values may |
| be passed to a call to `tf.distribute.StrategyExtended.update` to update the |
| value of a variable. You may use |
| `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the |
| same locality as another variable. You may convert a "PerReplica" value to a |
| variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or |
| `tf.distribute.StrategyExtended.batch_reduce_to`. |
| |
| * _How to update a distributed variable_ |
| |
| A distributed variable is a variable created on multiple devices. As |
| discussed in the |
| [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute), |
| mirrored variables and SyncOnRead variables are two examples. The |
| standard pattern for updating distributed variables is to: |
| |
| 1. In your function passed to `tf.distribute.Strategy.run`, |
| compute a list of (update, variable) pairs. For example, the update might |
| be a gradient of the loss with respect to the variable. |
| 2. Switch to cross-replica mode by calling |
| `tf.distribute.get_replica_context().merge_call()` with the updates and |
| variables as arguments. |
| 3. Call |
| `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` |
| (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` |
| (for a list of variables) to sum the updates. |
| 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update |
| its value. |
| |
| Steps 2 through 4 are done automatically by the |
| `tf.keras.optimizers.Optimizer` class if you call its |
| `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica |
| context. |
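| |
| The four steps above can be sketched in plain Python (not TensorFlow |
| code), with two replicas holding mirrored copies of one variable; the |
| gradients and learning rate here are made-up numbers: |
| |
| ```python |
| mirrored_var = [1.0, 1.0]          # one copy per replica, kept in sync |
| per_replica_grads = [0.25, 0.25]   # step 1: computed on each replica |
| |
| # Steps 2-3: a cross-replica SUM reduction produces one update that is |
| # identical for every copy. |
| reduced = sum(per_replica_grads)  # 0.5 |
| |
| # Step 4: apply the same reduced update to every copy, so the mirrored |
| # variable stays in sync. |
| mirrored_var = [v - 0.5 * reduced for v in mirrored_var]  # [0.75, 0.75] |
| ``` |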
| |
| In fact, a higher-level way to update a distributed variable is to call |
| `assign` on the variable, as you would with a regular `tf.Variable`. |
| You can call the method in both _replica context_ and _cross-replica context_. |
| For a _mirrored variable_, calling `assign` in _replica context_ requires you |
| to specify the `aggregation` type in the variable constructor. In that case, |
| the context switching and sync described in steps 2 through 4 are handled for |
| you. If you call `assign` on a _mirrored variable_ in _cross-replica context_, |
| you can only assign a single value or assign values from another mirrored |
| variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead |
| variable_, in _replica context_, you can simply call `assign` on it and no |
| aggregation happens under the hood. In _cross-replica context_, you can only |
| assign a single value to a SyncOnRead variable. One example case is restoring |
| from a checkpoint: if the `aggregation` type of the variable is |
| `tf.VariableAggregation.SUM`, it is assumed that replica values were added |
| before checkpointing, so at the time of restoring, the value is divided by |
| the number of replicas and then assigned to each replica; if the `aggregation` |
| type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica |
| directly. |
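| |
| A plain-Python sketch (not TensorFlow code) of the checkpoint-restore |
| rule above, assuming a hypothetical checkpointed value of `8.0` and 2 |
| replicas: |
| |
| ```python |
| checkpoint_value, num_replicas = 8.0, 2 |
| |
| # aggregation == SUM: the saved value is assumed to be the sum of the |
| # replica values, so each replica receives an equal share. |
| sum_restore = [checkpoint_value / num_replicas] * num_replicas  # [4.0, 4.0] |
| |
| # aggregation == MEAN: the saved value is assigned to each replica |
| # directly. |
| mean_restore = [checkpoint_value] * num_replicas  # [8.0, 8.0] |
| ``` |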
| |
| """ |
| |
| def __init__(self, container_strategy): |
| self._container_strategy_weakref = weakref.ref(container_strategy) |
| self._default_device = None |
| # This property is used to determine if we should set drop_remainder=True |
| # when creating Datasets from numpy array inputs. |
| self._require_static_shapes = False |
| |
| def _container_strategy(self): |
| """Get the containing `tf.distribute.Strategy`. |
| |
| This should not generally be needed except when creating a new |
| `ReplicaContext` and to validate that the caller is in the correct |
| `scope()`. |
| |
| Returns: |
| The `tf.distribute.Strategy` such that `strategy.extended` is `self`. |
| """ |
| container_strategy = self._container_strategy_weakref() |
| assert container_strategy is not None |
| return container_strategy |
| |
| def _scope(self, strategy): |
| """Implementation of tf.distribute.Strategy.scope().""" |
| |
| def creator_with_resource_vars(next_creator, **kwargs): |
| """Variable creator to use in `_CurrentDistributionContext`.""" |
| _require_strategy_scope_extended(self) |
| kwargs["use_resource"] = True |
| kwargs["distribute_strategy"] = strategy |
| |
| # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid |
| # dereferencing a `Tensor` that has no `name`. We still need to |
| # propagate the metadata it's holding. |
| if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): |
| checkpoint_restore_uid = kwargs[ |
| "initial_value"].checkpoint_position.restore_uid |
| kwargs["initial_value"] = kwargs["initial_value"].wrapped_value |
| elif isinstance(kwargs["initial_value"], |
| trackable.CheckpointInitialValueCallable): |
| checkpoint_restore_uid = kwargs[ |
| "initial_value"].checkpoint_position.restore_uid |
| elif (isinstance(kwargs["initial_value"], functools.partial) and |
| isinstance(kwargs["initial_value"].func, |
| trackable.CheckpointInitialValueCallable)): |
| # Some libraries (e.g., Keras) create a partial function out of the |
| # initializer to bind shape/dtype, for example: |
| #   initial_val = functools.partial(initializer, shape, dtype=dtype) |
| # Therefore, to get the restore_uid we need to examine the "func" of |
| # the partial function. |
| checkpoint_restore_uid = kwargs[ |
| "initial_value"].func.checkpoint_position.restore_uid |
| else: |
| checkpoint_restore_uid = None |
| |
| created = self._create_variable(next_creator, **kwargs) |
| |
| if checkpoint_restore_uid is not None: |
| # pylint: disable=protected-access |
| # Let the checkpointing infrastructure know that the variable was |
| # already restored so it doesn't waste memory loading the value again. |
| # In the case of CheckpointInitialValueCallable this may already be |
| # done by the final variable creator, but it doesn't hurt to do it |
| # again. |
| created._maybe_initialize_trackable() |
| created._update_uid = checkpoint_restore_uid |
| # pylint: enable=protected-access |
| return created |
| |
| def distributed_getter(getter, *args, **kwargs): |
| if not self._allow_variable_partition(): |
| if kwargs.pop("partitioner", None) is not None: |
| tf_logging.log_first_n( |
| tf_logging.WARN, "Partitioned variables are disabled when using " |
| "the current tf.distribute.Strategy.", 1) |
| return getter(*args, **kwargs) |
| |
| return _CurrentDistributionContext( |
| strategy, |
| variable_scope.variable_creator_scope(creator_with_resource_vars), |
| variable_scope.variable_scope( |
| variable_scope.get_variable_scope(), |
| custom_getter=distributed_getter), self._default_device) |
| |
| def _allow_variable_partition(self): |
| return False |
| |
| def _create_variable(self, next_creator, **kwargs): |
| # Note: should support "colocate_with" argument. |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def variable_created_in_scope(self, v): |
| """Tests whether `v` was created while this strategy scope was active. |
| |
| Variables created inside the strategy scope are "owned" by it: |
| |
| >>> strategy = tf.distribute.MirroredStrategy() |
| >>> with strategy.scope(): |
| ... v = tf.Variable(1.) |
| >>> strategy.extended.variable_created_in_scope(v) |
| True |
| |
| Variables created outside the strategy are not owned by it: |
| |
| >>> strategy = tf.distribute.MirroredStrategy() |
| >>> v = tf.Variable(1.) |
| >>> strategy.extended.variable_created_in_scope(v) |
| False |
| |
| Args: |
| v: A `tf.Variable` instance. |
| |
| Returns: |
| True if `v` was created inside the scope, False if not. |
| """ |
| return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access |
| |
| def colocate_vars_with(self, colocate_with_variable): |
| """Scope that controls which devices variables will be created on. |
| |
| No operations should be added to the graph inside this scope; it |
| should only be used when creating variables (some implementations |
| work by changing variable creation, others work by using a |
| tf.compat.v1.colocate_with() scope). |
| |
| This may only be used inside `self.scope()`. |
| |
| Example usage: |
| |
| ``` |
| with strategy.scope(): |
| var1 = tf.Variable(...) |
| with strategy.extended.colocate_vars_with(var1): |
| # var2 and var3 will be created on the same device(s) as var1 |
| var2 = tf.Variable(...) |
| var3 = tf.Variable(...) |
| |
| def fn(v1, v2, v3): |
| # operates on v1 from var1, v2 from var2, and v3 from var3 |
| |
| # `fn` runs on every device `var1` is on, `var2` and `var3` will be there |
| # too. |
| strategy.extended.update(var1, fn, args=(var2, var3)) |
| ``` |
| |
| Args: |
| colocate_with_variable: A variable created in this strategy's `scope()`. |
| Variables created while in the returned context manager will be on the |
| same set of devices as `colocate_with_variable`. |
| |
| Returns: |
| A context manager. |
| """ |
| |
| def create_colocated_variable(next_creator, **kwargs): |
| _require_strategy_scope_extended(self) |
| kwargs["use_resource"] = True |
| kwargs["colocate_with"] = colocate_with_variable |
| return next_creator(**kwargs) |
| |
| _require_strategy_scope_extended(self) |
| self._validate_colocate_with_variable(colocate_with_variable) |
| return variable_scope.variable_creator_scope(create_colocated_variable) |
| |
| def _validate_colocate_with_variable(self, colocate_with_variable): |
| """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" |
| pass |
| |
| def _make_dataset_iterator(self, dataset): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _make_input_fn_iterator(self, input_fn, replication_mode): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _experimental_distribute_dataset(self, dataset, options): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _distribute_datasets_from_function(self, dataset_fn, options): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _experimental_distribute_values_from_function(self, value_fn): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _reduce(self, reduce_op, value): |
| # Default implementation until we have an implementation for each strategy. |
| dst = device_util.current() or self._default_device or "/device:CPU:0" |
| return self._local_results(self.reduce_to(reduce_op, value, dst))[0] |
| |
| def reduce_to(self, reduce_op, value, destinations, options=None): |
| """Combine (via e.g. sum or mean) values across replicas. |
| |
| `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed |
| variables. It supports both dense values and `tf.IndexedSlices`. |
| |
| This API currently can only be called in cross-replica context. Other |
| variants to reduce values across replicas are: |
| * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of |
| this API. |
| * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API |
| in replica context. It supports both batched and non-batched all-reduce. |
| * `tf.distribute.Strategy.reduce`: a more convenient method to reduce |
| to the host in cross-replica context. |
| |
| `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can |
| also pass in a `Tensor`, and the destinations will be the device of that |
| tensor. For all-reduce, pass the same to `value` and `destinations`. |
| |
| It can be used in `tf.distribute.ReplicaContext.merge_call` to write code |
| that works for all `tf.distribute.Strategy`. |
| |
| >>> @tf.function |
| ... def step_fn(var): |
| ... |
| ... def merge_fn(strategy, value, var): |
| ... # All-reduce the value. Note that `value` here is a |
| ... # `tf.distribute.DistributedValues`. |
| ... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, |
| ... value, destinations=var) |
| ... strategy.extended.update(var, lambda var, value: var.assign(value), |
| ... args=(reduced,)) |
| ... |
| ... value = tf.identity(1.) |
| ... tf.distribute.get_replica_context().merge_call(merge_fn, |
| ... args=(value, var)) |
| >>> |
| >>> def run(strategy): |
| ... with strategy.scope(): |
| ... v = tf.Variable(0.) |
| ... strategy.run(step_fn, args=(v,)) |
| ... return v |
| >>> |
| >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> |
| } |
| >>> run(tf.distribute.experimental.CentralStorageStrategy( |
| ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> |
| >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> |
| |
| Args: |
| reduce_op: a `tf.distribute.ReduceOp` value specifying how values should |
| be combined. Allows using string representation of the enum such as |
| "SUM", "MEAN". |
| value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. |
| destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a |
| `tf.Tensor` alike object, or a device string. It specifies the devices |
| to reduce to. To perform an all-reduce, pass the same to `value` and |
| `destinations`. Note that if it's a `tf.Variable`, the value is reduced |
| to the devices of that variable, and this method doesn't update the |
| variable. |
| options: a `tf.distribute.experimental.CommunicationOptions`. Options to |
| perform collective operations. This overrides the default options if the |
| `tf.distribute.Strategy` takes one in the constructor. See |
| `tf.distribute.experimental.CommunicationOptions` for details of the |
| options. |
| |
| Returns: |
| A tensor or value reduced to `destinations`. |
| """ |
| if options is None: |
| options = collective_util.Options() |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(destinations, (list, tuple)) |
| assert not isinstance(reduce_op, variable_scope.VariableAggregation) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| assert (reduce_op == reduce_util.ReduceOp.SUM or |
| reduce_op == reduce_util.ReduceOp.MEAN) |
| return self._reduce_to(reduce_op, value, destinations, options) |
| |
| def _reduce_to(self, reduce_op, value, destinations, options): |
| raise NotImplementedError("must be implemented in descendants") |
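The cross-replica reduce semantics documented above can be sketched in plain Python. This is an illustrative toy, not TensorFlow code: `reduce_to_sketch` and its dict-of-devices return value are invented names; the real API operates on `tf.distribute.DistributedValues` and device placements.

```python
# Hypothetical sketch of `reduce_to` semantics: per-replica values are
# combined with SUM or MEAN, and the reduced result is mirrored to every
# destination device.
def reduce_to_sketch(reduce_op, per_replica_values, destinations):
    total = sum(per_replica_values)
    if reduce_op == "MEAN":
        total /= len(per_replica_values)
    # The same reduced value lands on each destination device.
    return {device: total for device in destinations}

print(reduce_to_sketch("SUM", [1.0, 1.0], ["GPU:0"]))   # {'GPU:0': 2.0}
print(reduce_to_sketch("MEAN", [1.0, 3.0], ["GPU:0"]))  # {'GPU:0': 2.0}
```

Passing every replica's device as `destinations` turns this into the all-reduce described in the docstring.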
| |
| def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None): |
| """Combine multiple `reduce_to` calls into one for faster execution. |
| |
| Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. |
| It's more efficient than reducing each value separately. |
| |
| This API currently can only be called in cross-replica context. Other |
| variants to reduce values across replicas are: |
| * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of |
| this API. |
| * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API |
| in replica context. It supports both batched and non-batched all-reduce. |
| * `tf.distribute.Strategy.reduce`: a more convenient method to reduce |
| to the host in cross-replica context. |
| |
| See `reduce_to` for more information. |
| |
| >>> @tf.function |
| ... def step_fn(var): |
| ... |
| ... def merge_fn(strategy, value, var): |
| ... # All-reduce the value. Note that `value` here is a |
| ... # `tf.distribute.DistributedValues`. |
| ... reduced = strategy.extended.batch_reduce_to( |
| ... tf.distribute.ReduceOp.SUM, [(value, var)])[0] |
| ... strategy.extended.update(var, lambda var, value: var.assign(value), |
| ... args=(reduced,)) |
| ... |
| ... value = tf.identity(1.) |
| ... tf.distribute.get_replica_context().merge_call(merge_fn, |
| ... args=(value, var)) |
| >>> |
| >>> def run(strategy): |
| ... with strategy.scope(): |
| ... v = tf.Variable(0.) |
| ... strategy.run(step_fn, args=(v,)) |
| ... return v |
| >>> |
| >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> |
| } |
| >>> run(tf.distribute.experimental.CentralStorageStrategy( |
| ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> |
| >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> |
| |
| Args: |
| reduce_op: a `tf.distribute.ReduceOp` value specifying how values should |
| be combined. Allows using string representation of the enum such as |
| "SUM", "MEAN". |
| value_destination_pairs: a sequence of (value, destinations) pairs. See |
| `tf.distribute.StrategyExtended.reduce_to` for descriptions. |
| options: a `tf.distribute.experimental.CommunicationOptions`. Options to |
| perform collective operations. This overrides the default options if the |
| `tf.distribute.Strategy` takes one in the constructor. See |
| `tf.distribute.experimental.CommunicationOptions` for details of the |
| options. |
| |
| Returns: |
| A list of reduced values, one per pair in `value_destination_pairs`. |
| """ |
| if options is None: |
| options = collective_util.Options() |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(reduce_op, variable_scope.VariableAggregation) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| return self._batch_reduce_to(reduce_op, value_destination_pairs, options) |
| |
| def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): |
| return [ |
| self.reduce_to(reduce_op, t, destinations=v, options=options) |
| for t, v in value_destination_pairs |
| ] |
| |
| def _replica_ctx_all_reduce(self, reduce_op, value, options=None): |
| """All-reduce `value` across all replicas so that all get the final result. |
| |
| If `value` is a nested structure of tensors, all-reduces of these tensors |
| will be batched when possible. `options` can be set to hint the batching |
| behavior. |
| |
| This API must be called in a replica context. |
| |
| Args: |
| reduce_op: A `tf.distribute.ReduceOp` value specifying how values should |
| be combined. |
| value: Value to be reduced. A tensor or a nested structure of tensors. |
| options: A `tf.distribute.experimental.CommunicationOptions`. Options to |
| perform collective operations. This overrides the default options if the |
| `tf.distribute.Strategy` takes one in the constructor. |
| |
| Returns: |
| A tensor or a nested structure of tensors with the reduced values. The |
| structure is the same as `value`. |
| """ |
| if options is None: |
| options = collective_util.Options() |
| replica_context = distribution_strategy_context.get_replica_context() |
| assert replica_context, ( |
| "`StrategyExtended._replica_ctx_all_reduce` must be called in" |
| " a replica context") |
| |
| def merge_fn(_, flat_value): |
| return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value], |
| options) |
| |
| reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),)) |
| return nest.pack_sequence_as(value, reduced) |
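The flatten → batch-reduce → repack pattern used in `_replica_ctx_all_reduce` can be illustrated with a pure-Python toy (hypothetical names; `nest.flatten`/`nest.pack_sequence_as` handle arbitrary nests in TF, while here each replica's "structure" is simply a list of floats):

```python
# Sketch of batched all-reduce over a structure: group the i-th element
# across replicas, reduce each group once, and hand back a result with
# the same structure as each replica's input.
def all_reduce_nested_sketch(reduce_op, per_replica_structures):
    n = len(per_replica_structures)
    # zip(*...) groups the i-th tensor from every replica together,
    # which is what lets the reductions be batched.
    groups = list(zip(*per_replica_structures))
    reduced = [sum(g) for g in groups]
    if reduce_op == "MEAN":
        reduced = [x / n for x in reduced]
    return reduced

# Replica 0 holds [1.0, 2.0]; replica 1 holds [3.0, 4.0].
print(all_reduce_nested_sketch("SUM", [[1.0, 2.0], [3.0, 4.0]]))  # [4.0, 6.0]
```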
| |
| def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True): |
| """Run `fn` with `args` and `kwargs` to update `var`.""" |
| # This method is called by ReplicaContext.update. Strategies that want to |
| # avoid merge_call in this path should override this method. |
| replica_context = distribution_strategy_context.get_replica_context() |
| if not replica_context: |
| raise ValueError("`StrategyExtended._replica_ctx_update` must be called " |
| "in a replica context.") |
| |
| def merge_fn(_, *merged_args, **merged_kwargs): |
| return self.update(var, fn, merged_args, merged_kwargs, group=group) |
| |
| return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) |
| |
| def _gather_to(self, value, destinations, axis, options=None): |
| """Gather `value` across replicas along axis-th dimension to `destinations`. |
| |
| `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like |
| objects along the `axis`-th dimension. It supports only dense tensors, NOT |
| sparse tensors. This API can only be called in cross-replica context. |
| |
| Args: |
| value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. |
| destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a |
| `tf.Tensor` alike object, or a device string. It specifies the devices |
| to reduce to. To perform an all-gather, pass the same to `value` and |
| `destinations`. Note that if it's a `tf.Variable`, the value is reduced |
| to the devices of that variable, and this method doesn't update the |
| variable. |
| axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the |
| range [0, rank(value)). |
| options: a `tf.distribute.experimental.CommunicationOptions`. Options to |
| perform collective operations. This overrides the default options if the |
| `tf.distribute.Strategy` takes one in the constructor. See |
| `tf.distribute.experimental.CommunicationOptions` for details of the |
| options. |
| |
| Returns: |
| A tensor or value gathered to `destinations`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(destinations, (list, tuple)) |
| if options is None: |
| options = collective_util.Options() |
| return self._gather_to_implementation(value, destinations, axis, options) |
| |
| def _gather_to_implementation(self, value, destinations, axis, options): |
| raise NotImplementedError("_gather_to must be implemented in descendants") |
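The gather semantics above amount to concatenating per-replica dense tensors along `axis`. A minimal pure-Python sketch (invented name, axis 0 only, lists standing in for tensors):

```python
# Toy sketch of `_gather_to`: per-replica values are concatenated along
# the gather axis (axis 0 here) to form one value on the destination.
def gather_sketch(per_replica_values, axis=0):
    assert axis == 0, "sketch only handles axis 0"
    gathered = []
    for value in per_replica_values:
        gathered.extend(value)
    return gathered

# Replica 0 holds [1, 2]; replica 1 holds [3, 4].
print(gather_sketch([[1, 2], [3, 4]]))  # [1, 2, 3, 4]
```

An all-gather would place this concatenated result on every replica rather than a single destination.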
| |
| def _batch_gather_to(self, value_destination_pairs, axis, options=None): |
| _require_cross_replica_or_default_context_extended(self) |
| if options is None: |
| options = collective_util.Options() |
| return [ |
| self._gather_to(t, destinations=v, axis=axis, options=options) |
| for t, v in value_destination_pairs |
| ] |
| |
| def update(self, var, fn, args=(), kwargs=None, group=True): |
| """Run `fn` to update `var` using inputs mirrored to the same devices. |
| |
| `tf.distribute.StrategyExtended.update` takes a distributed variable `var` |
| to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It |
| applies `fn` to each component variable of `var` and passes corresponding |
| values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain |
| per-replica values. If they contain mirrored values, they will be unwrapped |
| before calling `fn`. For example, `fn` can be `assign_add` and `args` can be |
| a mirrored DistributedValues where each component contains the value to be |
| added to this mirrored variable `var`. Calling `update` will call |
| `assign_add` on each component variable of `var` with the corresponding |
| tensor value on that device. |
| |
| Example usage: |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices |
| with strategy.scope(): |
| v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) |
| def update_fn(v): |
| return v.assign(1.0) |
| result = strategy.extended.update(v, update_fn) |
| # result is |
| # Mirrored:{ |
| # 0: tf.Tensor(1.0, shape=(), dtype=float32), |
| # 1: tf.Tensor(1.0, shape=(), dtype=float32) |
| # } |
| ``` |
| |
| If `var` is mirrored across multiple devices, then this method implements |
| logic as follows: |
| |
| ```python |
| results = {} |
| for device, v in var: |
| with tf.device(device): |
| # args and kwargs will be unwrapped if they are mirrored. |
| results[device] = fn(v, *args, **kwargs) |
| return merged(results) |
| ``` |
| |
| Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with |
| `var`. |
| |
| Args: |
| var: Variable, possibly mirrored to multiple devices, to operate on. |
| fn: Function to call. Should take the variable as the first argument. |
| args: Tuple or list. Additional positional arguments to pass to `fn()`. |
| kwargs: Dict with keyword arguments to pass to `fn()`. |
| group: Boolean. Defaults to True. If False, the return value will be |
| unwrapped. |
| |
| Returns: |
| By default, the merged return value of `fn` across all replicas. The |
| merged result has dependencies to make sure that if it is evaluated at |
| all, the side effects (updates) will happen on every replica. If instead |
| "group=False" is specified, this function will return a nest of lists |
| where each list has an element per replica, and the caller is responsible |
| for ensuring all elements are executed. |
| """ |
| # TODO(b/178944108): Update the documentation to reflect the fact that |
| # `update` can be called in a replica context. |
| if kwargs is None: |
| kwargs = {} |
| replica_context = distribution_strategy_context.get_replica_context() |
| # pylint: disable=protected-access |
| if (replica_context is None or replica_context is |
| distribution_strategy_context._get_default_replica_context()): |
| fn = autograph.tf_convert( |
| fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| with self._container_strategy().scope(): |
| return self._update(var, fn, args, kwargs, group) |
| else: |
| return self._replica_ctx_update( |
| var, fn, args=args, kwargs=kwargs, group=group) |
| |
| def _update(self, var, fn, args, kwargs, group): |
| raise NotImplementedError("must be implemented in descendants") |
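The per-device update loop from the `update` docstring can be sketched in plain Python. This is a hypothetical illustration (dicts keyed by device stand in for mirrored variables and `DistributedValues`; `update_sketch` is an invented name):

```python
# Sketch of the mirrored-update loop: apply `fn` to each component
# variable, unwrapping any mirrored argument to its component for that
# same device.
def update_sketch(mirrored_var, fn, mirrored_args=()):
    results = {}
    for device, component in mirrored_var.items():
        # A dict argument is "mirrored": pick this device's component.
        args = [a[device] if isinstance(a, dict) else a for a in mirrored_args]
        results[device] = fn(component, *args)
    return results

var = {"GPU:0": 5.0, "GPU:1": 5.0}
delta = {"GPU:0": 1.0, "GPU:1": 1.0}
print(update_sketch(var, lambda v, d: v + d, [delta]))
# {'GPU:0': 6.0, 'GPU:1': 6.0}
```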
| |
| def _local_results(self, val): |
| """Returns local results per replica as a tuple.""" |
| if isinstance(val, values.DistributedValues): |
| return val._values # pylint: disable=protected-access |
| |
| if nest.is_nested(val): |
| replica_values = [] |
| |
| def get_values(x, index): |
| if isinstance(x, values.DistributedValues): |
| return x._values[index] # pylint: disable=protected-access |
| return x |
| |
| for i in range(len(self.worker_devices)): |
| replica_values.append( |
| nest.map_structure( |
| lambda x: get_values(x, i), # pylint: disable=cell-var-from-loop |
| val)) |
| return tuple(replica_values) |
| return (val,) |
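The nested-structure branch of `_local_results` can be mimicked with the standard library (a hedged toy: tuples stand in for `DistributedValues`, and `local_results_sketch` is an invented name; TF uses `nest.map_structure` over arbitrary nests):

```python
# Sketch of _local_results on a nested value: for each worker device i,
# pull the i-th component out of every "distributed" leaf, leaving plain
# leaves untouched.
def local_results_sketch(val, num_devices):
    def get(x, i):
        # A tuple leaf plays the role of a DistributedValues here.
        return x[i] if isinstance(x, tuple) else x
    return tuple(
        {k: get(v, i) for k, v in val.items()} for i in range(num_devices))

val = {"loss": (0.1, 0.2), "step": 7}
print(local_results_sketch(val, 2))
# ({'loss': 0.1, 'step': 7}, {'loss': 0.2, 'step': 7})
```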
| |
| def value_container(self, value): |
| """Returns the container that this per-replica `value` belongs to. |
| |
| Args: |
| value: A value returned by `run()` or a variable created in `scope()`. |
| |
| Returns: |
| A container that `value` belongs to. |
| If value does not belong to any container (including the case of |
| container having been destroyed), returns the value itself. |
| `value in experimental_local_results(value_container(value))` will |
| always be true. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _group(self, value, name=None): |
| """Implementation of `group`.""" |
| value = nest.flatten(self._local_results(value)) |
| |
| if len(value) != 1 or name is not None: |
| return control_flow_ops.group(value, name=name) |
| # Special handling for the common case of one op. |
| v, = value |
| if hasattr(v, "op"): |
| v = v.op |
| return v |
| |
| @property |
| def experimental_require_static_shapes(self): |
| """Returns `True` if static shape is required; `False` otherwise.""" |
| return self._require_static_shapes |
| |
| @property |
| def _num_replicas_in_sync(self): |
| """Returns number of replicas over which gradients are aggregated.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def worker_devices(self): |
| """Returns the tuple of all devices used for compute replica execution.""" |
| # TODO(josh11b): More docstring |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def parameter_devices(self): |
| """Returns the tuple of all devices used to place variables.""" |
| # TODO(josh11b): More docstring |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| """Configures the strategy class.""" |
| del session_config, cluster_spec, task_type, task_id |
| |
| def _update_config_proto(self, config_proto): |
| return copy.deepcopy(config_proto) |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings. |
| |
| Multi-worker training refers to the setup where the training is |
| distributed across multiple workers, as opposed to the case where |
| only a local process performs the training. This function is |
| used by higher-level APIs such as Keras' `model.fit()` to infer, |
| for example, whether or not a distribute coordinator should be run |
| (and thus TensorFlow servers started for communication with other |
| servers in the cluster), or whether or not saving/restoring |
| checkpoints is relevant for preemption fault tolerance. |
| |
| Subclasses should override this to report whether the strategy is |
| currently in a multi-worker setup. |
| |
| Experimental. Signature and implementation are subject to change. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| |
| @tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring |
| class StrategyExtendedV1(StrategyExtendedV2): |
| |
| __doc__ = StrategyExtendedV2.__doc__ |
| |
| def experimental_make_numpy_dataset(self, numpy_input, session=None): |
| """Makes a dataset for input provided via a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
| |
| Args: |
| numpy_input: A nest of NumPy input arrays that will be distributed evenly |
| across all replicas. Note that lists of NumPy arrays are stacked, as |
| that is normal `tf.data.Dataset` behavior. |
| session: (TensorFlow v1.x graph execution only) A session used for |
| initialization. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| return self._experimental_make_numpy_dataset(numpy_input, session=session) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def broadcast_to(self, tensor, destinations): |
| """Mirror a tensor on one device to all worker devices. |
| |
| Args: |
| tensor: A Tensor value to broadcast. |
| destinations: A mirrored variable or device string specifying the |
| destination devices to copy `tensor` to. |
| |
| Returns: |
| A value mirrored to `destinations` devices. |
| """ |
| assert destinations is not None # from old strategy.broadcast() |
| # TODO(josh11b): More docstring |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(destinations, (list, tuple)) |
| return self._broadcast_to(tensor, destinations) |
| |
| def _broadcast_to(self, tensor, destinations): |
| raise NotImplementedError("must be implemented in descendants") |
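The mirroring performed by `broadcast_to` is simple to state as a pure-Python toy (a hypothetical sketch, with a dict keyed by device standing in for a mirrored value):

```python
# Sketch of `broadcast_to`: copy one tensor value to every destination
# device, yielding a mirrored value.
def broadcast_sketch(tensor, destinations):
    return {device: tensor for device in destinations}

print(broadcast_sketch(1.0, ["GPU:0", "GPU:1"]))
# {'GPU:0': 1.0, 'GPU:1': 1.0}
```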
| |
| def experimental_run_steps_on_iterator(self, |
| fn, |
| iterator, |
| iterations=1, |
| initial_loop_values=None): |
| """DEPRECATED: please use `run` instead. |
| |
| Run `fn` with input from `iterator` for `iterations` times. |
| |
| This method can be used to run a step function for training a number of |
| times using input from a dataset. |
| |
| Args: |
| fn: function to run using this distribution strategy. The function must |
| have the following signature: `def fn(context, inputs)`. `context` is an |
| instance of `MultiStepContext` that will be passed when `fn` is run. |
| `context` can be used to specify the outputs to be returned from `fn` |
| by calling `context.set_last_step_output`. It can also be used to |
| capture non tensor outputs by `context.set_non_tensor_output`. See |
| `MultiStepContext` documentation for more information. `inputs` will |
| have same type/structure as `iterator.get_next()`. Typically, `fn` |
| will use `call_for_each_replica` method of the strategy to distribute |
| the computation over multiple replicas. |
| iterator: Iterator of a dataset that represents the input for `fn`. The |
| caller is responsible for initializing the iterator as needed. |
| iterations: (Optional) Number of iterations that `fn` should be run. |
| Defaults to 1. |
| initial_loop_values: (Optional) Initial values to be passed into the |
| loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove |
| initial_loop_values argument when we have a mechanism to infer the |
| outputs of `fn`. |
| |
| Returns: |
| Returns the `MultiStepContext` object which has the following properties, |
| among other things: |
| - run_op: An op that runs `fn` `iterations` times. |
| - last_step_outputs: A dictionary containing tensors set using |
| `context.set_last_step_output`. Evaluating this returns the value of |
| the tensors after the last iteration. |
| - non_tensor_outputs: A dictionary containing anything that was set by |
| `fn` by calling `context.set_non_tensor_output`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| with self._container_strategy().scope(): |
| return self._experimental_run_steps_on_iterator(fn, iterator, iterations, |
| initial_loop_values) |
| |
| def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, |
| initial_loop_values): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def call_for_each_replica(self, fn, args=(), kwargs=None): |
| """Run `fn` once per replica. |
| |
| `fn` may call `tf.get_replica_context()` to access methods such as |
| `replica_id_in_sync_group` and `merge_call()`. |
| |
| `merge_call()` is used to communicate between the replicas and |
| re-enter the cross-replica context. All replicas pause their execution |
| when they encounter a `merge_call()` call. After that the |
| `merge_fn` function is executed. Its results are then unwrapped and |
| given back to each replica call. After that, execution resumes until |
| `fn` is complete or encounters another `merge_call()`. Example: |
| |
| ```python |
| # Called once in "cross-replica" context. |
| def merge_fn(distribution, three_plus_replica_id): |
| # sum the values across replicas |
| return sum(distribution.experimental_local_results(three_plus_replica_id)) |
| |
| # Called once per replica in `distribution`, in a "replica" context. |
| def fn(three): |
| replica_ctx = tf.get_replica_context() |
| v = three + replica_ctx.replica_id_in_sync_group |
| # Computes the sum of the `v` values across all replicas. |
| s = replica_ctx.merge_call(merge_fn, args=(v,)) |
| return s + v |
| |
| with distribution.scope(): |
| # in "cross-replica" context |
| ... |
| merged_results = distribution.run(fn, args=[3]) |
| # merged_results has the values from every replica execution of `fn`. |
| # This statement prints a list: |
| print(distribution.experimental_local_results(merged_results)) |
| ``` |
| |
| Args: |
| fn: function to run (will be run once per replica). |
| args: Tuple or list with positional arguments for `fn`. |
| kwargs: Dict with keyword arguments for `fn`. |
| |
| Returns: |
| Merged return value of `fn` across all replicas. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| with self._container_strategy().scope(): |
| return self._call_for_each_replica(fn, args, kwargs) |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| raise NotImplementedError("must be implemented in descendants") |
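The pause-at-`merge_call` flow from the docstring example can be simulated in plain Python (an invented toy: replicas run to the merge point, one cross-replica merge runs, then every replica resumes with the merged result):

```python
# Toy simulation of call_for_each_replica + merge_call: each "replica"
# computes v = three + replica_id, all replicas pause at merge_call,
# merge_fn sums the per-replica values, and each replica returns s + v.
def simulate_call_for_each_replica(fn_to_merge_point, merge_fn, num_replicas):
    # Phase 1: every replica runs up to its merge_call.
    per_replica = [fn_to_merge_point(rid) for rid in range(num_replicas)]
    # Phase 2: a single cross-replica merge over all replicas' values.
    merged = merge_fn(per_replica)
    # Phase 3: each replica resumes with the merged result.
    return [merged + v for v in per_replica]

# Mirrors the docstring example with 2 replicas and `three = 3`:
# v is 3 and 4, the merged sum s is 7, and each replica returns s + v.
print(simulate_call_for_each_replica(lambda rid: 3 + rid, sum, 2))  # [10, 11]
```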
| |
| def read_var(self, v): |
| """Reads the value of a variable. |
| |
| Returns the aggregate value of a replica-local variable, or the |
| (read-only) value of any other variable. |
| |
| Args: |
| v: A variable allocated within the scope of this `tf.distribute.Strategy`. |
| |
| Returns: |
| A tensor representing the value of `v`, aggregated across replicas if |
| necessary. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def update_non_slot( |
| self, colocate_with, fn, args=(), kwargs=None, group=True): |
| """Runs `fn(*args, **kwargs)` on `colocate_with` devices. |
| |
| Used to update non-slot variables. |
| |
| DEPRECATED: TF 1.x ONLY. |
| |
| Args: |
| colocate_with: Devices returned by `non_slot_devices()`. |
| fn: Function to execute. |
| args: Tuple or list. Positional arguments to pass to `fn()`. |
| kwargs: Dict with keyword arguments to pass to `fn()`. |
| group: Boolean. Defaults to True. If False, the return value will be |
| unwrapped. |
| |
| Returns: |
| Return value of `fn`, possibly merged across devices. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| fn = autograph.tf_convert( |
| fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| with self._container_strategy().scope(): |
| return self._update_non_slot(colocate_with, fn, args, kwargs, group) |
| |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, group): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def non_slot_devices(self, var_list): |
| """Device(s) for non-slot variables. |
| |
| DEPRECATED: TF 1.x ONLY. |
| |
| This method returns non-slot devices where non-slot variables are placed. |
| Users can create non-slot variables on these devices by using a block: |
| |
| ```python |
| with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)): |
| ... |
| ``` |
| |
| Args: |
| var_list: The list of variables being optimized, needed with the |
| default `tf.distribute.Strategy`. |
| Returns: |
| A sequence of devices for non-slot variables. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _use_merge_call(self): |
| """Whether to use merge-calls inside the distributed strategy.""" |
| return True |
| |
| @property |
| def experimental_between_graph(self): |
| """Whether the strategy uses between-graph replication or not. |
| |
| This is expected to return a constant value that will not be changed |
| throughout its life cycle. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def experimental_should_init(self): |
| """Whether initialization is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def should_checkpoint(self): |
| """Whether checkpointing is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def should_save_summary(self): |
| """Whether saving summaries is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| |
| # A note about the difference between the context managers |
| # `ReplicaContext` (defined here) and `_CurrentDistributionContext` |
| # (defined above) used by `tf.distribute.Strategy.scope()`: |
| # |
| # * a ReplicaContext is only present during a `run()` |
| # call (except during a `merge_call` call) and in such a scope it |
| # will be returned by calls to `get_replica_context()`. Implementers of new |
| # Strategy descendants will frequently also need to |
| # define a descendant of ReplicaContext, and are responsible for |
| # entering and exiting this context. |
| # |
| # * Strategy.scope() sets up a variable_creator scope that |
| # changes variable creation calls (e.g. to make mirrored |
| # variables). This is intended as an outer scope that users enter once |
| # around their model creation and graph definition. There is no |
| # anticipated need to define descendants of _CurrentDistributionContext. |
| # It sets the current Strategy for purposes of |
| # `get_strategy()` and `has_strategy()` |
| # and switches the thread mode to a "cross-replica context". |
| class ReplicaContextBase(object): |
| """A class with a collection of APIs that can be called in a replica context. |
| |
| You can use `tf.distribute.get_replica_context` to get an instance of |
| `ReplicaContext`, which can only be called inside the function passed to |
| `tf.distribute.Strategy.run`. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) |
| >>> def func(): |
| ... replica_context = tf.distribute.get_replica_context() |
| ... return replica_context.replica_id_in_sync_group |
| >>> strategy.run(func) |
| PerReplica:{ |
| 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>, |
| 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> |
| } |
| """ |
| |
| def __init__(self, strategy, replica_id_in_sync_group): |
| """Creates a ReplicaContext. |
| |
| Args: |
| strategy: A `tf.distribute.Strategy`. |
| replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an |
| integer whenever possible to avoid issues with nested `tf.function`. It |
| accepts a `Tensor` only to be compatible with `tpu.replicate`. |
| """ |
| self._strategy = strategy |
| self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access |
| self) |
| if not (replica_id_in_sync_group is None or |
| tensor_util.is_tf_type(replica_id_in_sync_group) or |
| isinstance(replica_id_in_sync_group, int)): |
| raise ValueError( |
| "replica_id_in_sync_group can only be an integer, a Tensor or None.") |
| self._replica_id_in_sync_group = replica_id_in_sync_group |
| # We need this check because TPUContext extends from ReplicaContext and |
| # does not pass a strategy object since it is used by TPUEstimator. |
| if strategy: |
| self._local_replica_id = strategy.extended._get_local_replica_id( |
| replica_id_in_sync_group) |
| self._summary_recording_distribution_strategy = None |
| |
| @doc_controls.do_not_generate_docs |
| def __enter__(self): |
| _push_per_thread_mode(self._thread_context) |
| |
| def replica_id_is_zero(): |
| return math_ops.equal(self.replica_id_in_sync_group, |
| constant_op.constant(0)) |
| |
| summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access |
| self._summary_recording_distribution_strategy = ( |
| summary_state.is_recording_distribution_strategy) |
| summary_state.is_recording_distribution_strategy = replica_id_is_zero |
| |
| @doc_controls.do_not_generate_docs |
| def __exit__(self, exception_type, exception_value, traceback): |
| summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access |
| summary_state.is_recording_distribution_strategy = ( |
| self._summary_recording_distribution_strategy) |
| _pop_per_thread_mode() |
| |
| def merge_call(self, merge_fn, args=(), kwargs=None): |
| """Merge args across replicas and run `merge_fn` in a cross-replica context. |
| |
| This allows communication and coordination when there are multiple calls |
| to the step_fn triggered by a call to `strategy.run(step_fn, ...)`. |
| |
| See `tf.distribute.Strategy.run` for an explanation. |
| |
| If not inside a distributed scope, this is equivalent to: |
| |
| ``` |
| strategy = tf.distribute.get_strategy() |
| with cross-replica-context(strategy): |
| return merge_fn(strategy, *args, **kwargs) |
| ``` |
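| |
| Conceptually, `merge_call` pauses the per-replica function, bundles the |
| per-replica arguments together, runs `merge_fn` once with all of them, and |
| mirrors the result back to every replica. A minimal, TF-free sketch of that |
| control flow (the names below are illustrative, not TF APIs): |
| |
| ```python |
| def simulated_merge_call(merge_fn, per_replica_values): |
|   # Each replica's value is collected into one batch ("PerReplica"), and |
|   # merge_fn runs once in the cross-replica context... |
|   merged = merge_fn(per_replica_values) |
|   # ...then the single result is mirrored back to every replica. |
|   return [merged for _ in per_replica_values] |
| |
| results = simulated_merge_call(sum, [1.0, 2.0])  # values from 2 replicas |
| ``` |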
| |
| Args: |
| merge_fn: Function that joins arguments from threads that are given as |
| PerReplica. It accepts a `tf.distribute.Strategy` object as the |
| first argument. |
| args: List or tuple with positional per-thread arguments for `merge_fn`. |
| kwargs: Dict with keyword per-thread arguments for `merge_fn`. |
| |
| Returns: |
| The return value of `merge_fn`, except for `PerReplica` values which are |
| unpacked. |
| """ |
| require_replica_context(self) |
| if kwargs is None: |
| kwargs = {} |
| |
| merge_fn = autograph.tf_convert( |
| merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| return self._merge_call(merge_fn, args, kwargs) |
| |
| def _merge_call(self, merge_fn, args, kwargs): |
| """Default implementation for single replica.""" |
| _push_per_thread_mode( # thread-local, so not needed with multiple threads |
| distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access |
| try: |
| return merge_fn(self._strategy, *args, **kwargs) |
| finally: |
| _pop_per_thread_mode() |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns number of replicas that are kept in sync.""" |
| return self._strategy.num_replicas_in_sync |
| |
| @property |
| def replica_id_in_sync_group(self): |
| """Returns the id of the replica. |
| |
| This identifies the replica among all replicas that are kept in sync. The |
| value of the replica id can range from 0 to |
| `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1. |
| |
| NOTE: This is not guaranteed to be the same ID as the XLA replica ID used |
| for low-level operations such as collective_permute. |
| |
| Returns: |
| a `Tensor`. |
| """ |
| # It's important to prefer making the Tensor at call time whenever possible. |
| # Keeping Tensors in global states doesn't work well with nested |
| # tf.function, since it's possible that the tensor is generated in one func |
| # graph and gets captured by another, which results in a subtle "An op |
| # outside of the function building code is being passed a Graph tensor" |
| # error. Making the tensor at call time ensures it lives in the same graph |
| # where it's used. However, to be compatible with tpu.replicate(), |
| # self._replica_id_in_sync_group can also be a Tensor. |
| if tensor_util.is_tf_type(self._replica_id_in_sync_group): |
| return self._replica_id_in_sync_group |
| return constant_op.constant( |
| self._replica_id_in_sync_group, |
| dtypes.int32, |
| name="replica_id_in_sync_group") |
| |
| @property |
| def _replica_id(self): |
| """This is the local replica id in a given sync group.""" |
| return self._local_replica_id |
| |
| @property |
| def strategy(self): |
| """The current `tf.distribute.Strategy` object.""" |
| return self._strategy |
| |
| @property |
| @deprecation.deprecated(None, "Please avoid relying on devices property.") |
| def devices(self): |
| """Returns the devices this replica is to be executed on, as a tuple of strings. |
| |
| NOTE: For `tf.distribute.MirroredStrategy` and |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a |
| nested list of device strings, e.g., `[["GPU:0"]]`. |
| """ |
| require_replica_context(self) |
| return (device_util.current(),) |
| |
| def all_reduce(self, reduce_op, value, options=None): |
| """All-reduces `value` across all replicas. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def step_fn(): |
| ... ctx = tf.distribute.get_replica_context() |
| ... value = tf.identity(1.) |
| ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value) |
| >>> strategy.experimental_local_results(strategy.run(step_fn)) |
| (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, |
| <tf.Tensor: shape=(), dtype=float32, numpy=2.0>) |
| |
| It supports batched operations. You can pass a list of values and it |
| attempts to batch them when possible. You can also specify `options` |
| to indicate the desired batching behavior, e.g. batch the values into |
| multiple packs so that they can better overlap with computations. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def step_fn(): |
| ... ctx = tf.distribute.get_replica_context() |
| ... value1 = tf.identity(1.) |
| ... value2 = tf.identity(2.) |
| ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2]) |
| >>> strategy.experimental_local_results(strategy.run(step_fn)) |
| ([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, |
| <tf.Tensor: shape=(), dtype=float32, numpy=4.0>], |
| [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, |
| <tf.Tensor: shape=(), dtype=float32, numpy=4.0>]) |
| |
| Note that all replicas need to participate in the all-reduce, otherwise this |
| operation hangs. If there are multiple all-reduces, they need to |
| execute in the same order on all replicas. Dispatching an all-reduce based |
| on conditions is usually error-prone. |
| |
| This API currently can only be called in the replica context. Other |
| variants to reduce values across replicas are: |
| * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API |
| in the cross-replica context. |
| * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and |
| all-reduce API in the cross-replica context. |
| * `tf.distribute.Strategy.reduce`: a more convenient method to reduce |
| to the host in cross-replica context. |
| |
| Args: |
| reduce_op: a `tf.distribute.ReduceOp` value specifying how values should |
| be combined. Allows using string representation of the enum such as |
| "SUM", "MEAN". |
| value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts. |
| The structure and the shapes of the `tf.Tensor` need to be the same on |
| all replicas. |
| options: a `tf.distribute.experimental.CommunicationOptions`. Options to |
| perform collective operations. This overrides the default options if the |
| `tf.distribute.Strategy` takes one in the constructor. See |
| `tf.distribute.experimental.CommunicationOptions` for details of the |
| options. |
| |
| Returns: |
| A nested structure of `tf.Tensor` with the reduced values. The structure |
| is the same as `value`. |
| """ |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| if options is None: |
| options = collective_util.Options() |
| |
| def batch_all_reduce(strategy, *value_flat): |
| return strategy.extended.batch_reduce_to( |
| reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], |
| options) |
| |
| if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: |
| # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. |
| @custom_gradient.custom_gradient |
| def grad_wrapper(*xs): |
| ys = self.merge_call(batch_all_reduce, args=xs) |
| # The gradient of an all-sum is itself an all-sum (all-mean, likewise). |
| return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) |
| return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) |
| else: |
| # TODO(cjfj): Implement gradients for other reductions. |
| reduced = nest.pack_sequence_as( |
| value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) |
| return nest.map_structure(array_ops.prevent_gradient, reduced) |
| |
| # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient |
| # all-reduce. It would return a function returning the result of reducing `t` |
| # across all replicas. The caller would wait to call this function until they |
| # needed the reduce result, allowing an efficient implementation: |
| # * With eager execution, the reduction could be performed asynchronously |
| # in the background, not blocking until the result was needed. |
| # * When constructing a graph, it could batch up all reduction requests up |
| # to that point that the first result is needed. Most likely this can be |
| # implemented in terms of `merge_call()` and `batch_reduce_to()`. |
| |
| |
| @tf_export("distribute.ReplicaContext", v1=[]) |
| class ReplicaContext(ReplicaContextBase): |
| |
| __doc__ = ReplicaContextBase.__doc__ |
| |
| def all_gather(self, value, axis, options=None): |
| """All-gathers `value` across all replicas along `axis`. |
| |
| Note: An `all_gather` method can only be called in replica context. For |
| a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. |
| All replicas need to participate in the all-gather, otherwise this |
| operation hangs. So if `all_gather` is called in any replica, it must be |
| called in all replicas. |
| |
| Note: If there are multiple `all_gather` calls, they need to be executed in |
| the same order on all replicas. Dispatching `all_gather` based on conditions |
| is usually error-prone. |
| |
| For all strategies except `tf.distribute.TPUStrategy`, the input |
| `value` on different replicas must have the same rank, and their shapes must |
| be the same in all dimensions except the `axis`-th dimension. In other |
| words, their shapes cannot differ in any dimension `d` that does not |
| equal the `axis` argument. For example, given a |
| `tf.distribute.DistributedValues` with component tensors of shape |
| `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call |
| `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` |
| or `all_gather(..., axis=2, ...)`. However, with |
| `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and |
| same shape. |
| |
| Note: The input `value` must have a non-zero rank. For rank-0 inputs, |
| consider using `tf.expand_dims` to add a dimension before gathering. |
| |
| You can pass in a single tensor to all-gather: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> @tf.function |
| ... def gather_value(): |
| ... ctx = tf.distribute.get_replica_context() |
| ... local_value = tf.constant([1, 2, 3]) |
| ... return ctx.all_gather(local_value, axis=0) |
| >>> result = strategy.run(gather_value) |
| >>> result |
| PerReplica:{ |
| 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, |
| 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> |
| } |
| >>> strategy.experimental_local_results(result) |
| (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], |
| dtype=int32)>, |
| <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], |
| dtype=int32)>) |
| |
| |
| You can also pass in a nested structure of tensors to all-gather, say, a |
| list: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> @tf.function |
| ... def gather_nest(): |
| ... ctx = tf.distribute.get_replica_context() |
| ... value_1 = tf.constant([1, 2, 3]) |
| ... value_2 = tf.constant([[1, 2], [3, 4]]) |
| ... # all_gather a nest of `tf.distribute.DistributedValues` |
| ... return ctx.all_gather([value_1, value_2], axis=0) |
| >>> result = strategy.run(gather_nest) |
| >>> result |
| [PerReplica:{ |
| 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, |
| 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> |
| }, PerReplica:{ |
| 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= |
| array([[1, 2], |
| [3, 4], |
| [1, 2], |
| [3, 4]], dtype=int32)>, |
| 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= |
| array([[1, 2], |
| [3, 4], |
| [1, 2], |
| [3, 4]], dtype=int32)> |
| }] |
| >>> strategy.experimental_local_results(result) |
| ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, |
| <tf.Tensor: shape=(4, 2), dtype=int32, numpy= |
| array([[1, 2], |
| [3, 4], |
| [1, 2], |
| [3, 4]], dtype=int32)>], |
| [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, |
| <tf.Tensor: shape=(4, 2), dtype=int32, numpy= |
| array([[1, 2], |
| [3, 4], |
| [1, 2], |
| [3, 4]], dtype=int32)>]) |
| |
| |
| What if you are all-gathering tensors with different shapes on different |
| replicas? Consider the following example with two replicas, where you have |
| `value` as a nested structure consisting of two items to all-gather, `a` and |
| `b`. |
| |
| On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`. |
| |
| On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`. |
| |
| Result for `all_gather` with `axis`=0 (on each of the replicas) is: |
| |
| ```{'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}``` |
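| |
| The concatenation semantics can be sketched without TF (the names below |
| are illustrative, not TF APIs): each replica contributes its local values, |
| and every replica receives the concatenation along the gather axis: |
| |
| ```python |
| def simulated_all_gather(per_replica_values): |
|   # Concatenate the per-replica lists along axis 0... |
|   gathered = [x for v in per_replica_values for x in v] |
|   # ...and mirror the same result back to every replica. |
|   return [gathered for _ in per_replica_values] |
| |
| # 'a' from the example above: [0] on replica 0, [1] on replica 1. |
| gathered_a = simulated_all_gather([[0], [1]]) |
| ``` |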
| |
| Args: |
| value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, |
| or a `tf.distribute.DistributedValues` instance. The structure of the |
| `tf.Tensor` needs to be the same on all replicas. The underlying tensor |
| constructs can only be dense tensors with non-zero rank, NOT |
| `tf.IndexedSlices`. |
| axis: 0-D int32 Tensor. Dimension along which to gather. |
| options: a `tf.distribute.experimental.CommunicationOptions`. Options to |
| perform collective operations. This overrides the default options if the |
| `tf.distribute.Strategy` takes one in the constructor. See |
| `tf.distribute.experimental.CommunicationOptions` for details of the |
| options. |
| |
| Returns: |
| A nested structure of `tf.Tensor` with the gathered values. The structure |
| is the same as `value`. |
| """ |
| for v in nest.flatten(value): |
| if isinstance(v, ops.IndexedSlices): |
| raise NotImplementedError("all_gather does not support IndexedSlices") |
| |
| if options is None: |
| options = collective_util.Options() |
| |
| def batch_all_gather(strategy, *value_flat): |
| return strategy.extended._batch_gather_to( # pylint: disable=protected-access |
| [(v, _batch_reduce_destination(v)) for v in value_flat], axis, |
| options) |
| |
| @custom_gradient.custom_gradient |
| def grad_wrapper(*xs): |
| ys = self.merge_call(batch_all_gather, args=xs) |
| |
| def grad(*dy_s): |
| grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s) |
| new_grads = [] |
| for i, grad in enumerate(grads): |
| input_shape = array_ops.shape(xs[i]) |
| axis_dim = array_ops.reshape(input_shape[axis], [1]) |
| with ops.control_dependencies([array_ops.identity(grads)]): |
| d = self.all_gather(axis_dim, axis=0) |
| begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group]) |
| end_dim = begin_dim + array_ops.shape(xs[i])[axis] |
| new_grad = array_ops.gather( |
| grad, axis=axis, indices=math_ops.range(begin_dim, end_dim)) |
| new_grads.append(new_grad) |
| return new_grads |
| |
| return ys, grad |
| |
| return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) |
| |
| def _update(self, var, fn, args=(), kwargs=None, group=True): |
| """Run `fn` to update `var` with `args` and `kwargs` in replica context. |
| |
| `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var` |
| to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. |
| `fn` applies to each component variable of `var` with corresponding input |
| values from `args` and `kwargs`. |
| |
| Example usage: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'CPU:0']) # 2 replicas |
| >>> with strategy.scope(): |
| ... distributed_variable = tf.Variable(5.0) |
| >>> distributed_variable |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0> |
| } |
| >>> def replica_fn(v): |
| ... value = tf.identity(1.0) |
| ... replica_context = tf.distribute.get_replica_context() |
| ... update_fn = lambda var, value: var.assign(value) |
| ... replica_context._update(v, update_fn, args=(value,)) |
| >>> strategy.run(replica_fn, args=(distributed_variable,)) |
| >>> distributed_variable |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> |
| } |
| |
| This API must be called in a replica context. |
| |
| Note that if `var` is a MirroredVariable (i.e., the type of variable created |
| under the scope of a synchronous strategy, synchronized on-write; see |
| `tf.VariableSynchronization` for more information) and `args`/`kwargs` |
| contain different values for different replicas, `var` will be dangerously |
| out of sync. We therefore recommend using `variable.assign(value)` whenever |
| possible, which aggregates the updates under the hood and guarantees their |
| synchronization. Prefer this API over `variable.assign(value)` only when |
| you need to run some pre-`assign` computation colocated with the variable |
| devices (i.e. where the variables reside; for MirroredStrategy they are |
| the same as the compute devices, for ParameterServerStrategy they are the |
| parameter servers). E.g., |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas |
| with strategy.scope(): |
| v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) |
| def replica_fn(inputs): |
| value = computation(inputs) |
| replica_context = tf.distribute.get_replica_context() |
| reduced_value = replica_context.all_reduce(value) |
| |
| def update_fn(var, value): |
| # this computation will colocate with `var`'s device |
| updated_value = post_reduce_pre_update_computation(value) |
| var.assign(updated_value) |
| |
| replica_context._update(v, update_fn, args=(reduced_value,)) |
| |
| strategy.run(replica_fn, args=(inputs,)) |
| ``` |
| |
| This code snippet is consistent across all strategies. If you directly |
| compute and `assign` in the replica context instead of wrapping the |
| computation with `update`, then for strategies with fewer variable devices |
| than compute devices (e.g., parameter server strategy, usually) the |
| `post_reduce_pre_update_computation` runs once per compute device rather |
| than once per variable device, which is less performant. |
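| |
| As a rough, TF-free illustration of that cost model (the device counts |
| below are assumptions for a hypothetical setup, not measured numbers): |
| |
| ```python |
| num_compute_devices = 4   # replicas running replica_fn |
| num_variable_devices = 1  # e.g. a single parameter server |
| |
| # Plain `assign` in replica context: the pre-update computation is |
| # repeated once per compute device (replica). |
| runs_without_update = num_compute_devices |
| |
| # Wrapped in `update`: the computation runs colocated with the variable, |
| # i.e. once per variable device. |
| runs_with_update = num_variable_devices |
| ``` |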
| |
| |
| Args: |
| var: Variable, possibly distributed to multiple devices, to operate on. |
| fn: Function to call. Should take the variable as the first argument. |
| args: Tuple or list. Additional positional arguments to pass to `fn()`. |
| kwargs: Dict with keyword arguments to pass to `fn()`. |
| group: Boolean. Defaults to True. Most strategies enter a merge_call to |
| conduct the update in a cross-replica context, and group=True guarantees |
| that the updates on all replicas are executed. |
| |
| Returns: |
| The return value of `fn` for the local replica. |
| """ |
| if kwargs is None: |
| kwargs = {} |
| return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group) # pylint: disable=protected-access |
| |
| |
| @tf_export(v1=["distribute.ReplicaContext"]) |
| class ReplicaContextV1(ReplicaContextBase): |
| __doc__ = ReplicaContextBase.__doc__ |
| |
| |
| def _batch_reduce_destination(x): |
| """Returns the destinations for batch all-reduce.""" |
| if isinstance(x, ops.Tensor): |
| # If this is a one device strategy. |
| return x.device |
| else: |
| return x |
| |
| |
| # ------------------------------------------------------------------------------ |
| |
| |
| _creating_default_strategy_singleton = False |
| |
| |
| class _DefaultDistributionStrategyV1(StrategyV1): |
| """Default `tf.distribute.Strategy` if none is explicitly selected.""" |
| |
| def __init__(self): |
| if not _creating_default_strategy_singleton: |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| super(_DefaultDistributionStrategyV1, |
| self).__init__(_DefaultDistributionExtended(self)) |
| |
| def __deepcopy__(self, memo): |
| del memo |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| |
| |
| class _DefaultDistributionStrategy(Strategy): |
| """Default `tf.distribute.Strategy` if none is explicitly selected.""" |
| |
| def __init__(self): |
| if not _creating_default_strategy_singleton: |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| super(_DefaultDistributionStrategy, self).__init__( |
| _DefaultDistributionExtended(self)) |
| |
| def __deepcopy__(self, memo): |
| del memo |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| |
| |
| class _DefaultDistributionContext(object): |
| """Context manager setting the default `tf.distribute.Strategy`.""" |
| |
| __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"] |
| |
| def __init__(self, strategy): |
| |
| def creator(next_creator, **kwargs): |
| _require_strategy_scope_strategy(strategy) |
| return next_creator(**kwargs) |
| |
| self._var_creator_scope = variable_scope.variable_creator_scope(creator) |
| self._strategy = strategy |
| self._nested_count = 0 |
| |
| def __enter__(self): |
| # Do not allow this scope to be entered while another strategy is in scope. |
| if distribution_strategy_context.has_strategy(): |
| raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") |
| if self._nested_count == 0: |
| self._var_creator_scope.__enter__() |
| self._nested_count += 1 |
| return self._strategy |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| self._nested_count -= 1 |
| if self._nested_count == 0: |
| try: |
| self._var_creator_scope.__exit__( |
| exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable creator scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| |
| class _DefaultDistributionExtended(StrategyExtendedV1): |
| """Implementation of _DefaultDistributionStrategy.""" |
| |
| def __init__(self, container_strategy): |
| super(_DefaultDistributionExtended, self).__init__(container_strategy) |
| self._retrace_functions_for_each_device = False |
| |
| def _scope(self, strategy): |
| """Context manager setting a variable creator and `self` as current.""" |
| return _DefaultDistributionContext(strategy) |
| |
| def colocate_vars_with(self, colocate_with_variable): |
| """Does not require `self.scope`.""" |
| _require_strategy_scope_extended(self) |
| return ops.colocate_with(colocate_with_variable) |
| |
| def variable_created_in_scope(self, v): |
| return v._distribute_strategy is None # pylint: disable=protected-access |
| |
| def _experimental_distribute_dataset(self, dataset, options): |
| return dataset |
| |
| def _distribute_datasets_from_function(self, dataset_fn, options): |
| return dataset_fn(InputContext()) |
| |
| def _experimental_distribute_values_from_function(self, value_fn): |
| return value_fn(ValueContext()) |
| |
| def _make_dataset_iterator(self, dataset): |
| return _DefaultDistributionExtended.DefaultInputIterator(dataset) |
| |
| def _make_input_fn_iterator(self, |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| dataset = input_fn(InputContext()) |
| return _DefaultDistributionExtended.DefaultInputIterator(dataset) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| numpy_flat = nest.flatten(numpy_input) |
| vars_flat = tuple( |
| variable_scope.variable(array_ops.zeros(i.shape, i.dtype), |
| trainable=False, use_resource=True) |
| for i in numpy_flat |
| ) |
| for v, i in zip(vars_flat, numpy_flat): |
| numpy_dataset.init_var_from_numpy(v, i, session) |
| vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) |
| return dataset_ops.Dataset.from_tensor_slices(vars_nested) |
| |
| def _broadcast_to(self, tensor, destinations): |
| if destinations is None: |
| return tensor |
| else: |
| raise NotImplementedError("TODO") |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): |
| return fn(*args, **kwargs) |
| |
| def _reduce_to(self, reduce_op, value, destinations, options): |
| # TODO(josh11b): Use destinations? |
| del reduce_op, destinations, options |
| return value |
| |
| def _gather_to_implementation(self, value, destinations, axis, options): |
| del destinations, axis, options |
| return value |
| |
| def _update(self, var, fn, args, kwargs, group): |
| # The implementations of _update() and _update_non_slot() are identical |
| # except _update() passes `var` as the first argument to `fn()`. |
| return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) |
| |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): |
| # TODO(josh11b): Figure out what we should be passing to UpdateContext() |
| # once that value is used for something. |
| with UpdateContext(colocate_with): |
| result = fn(*args, **kwargs) |
| if should_group: |
| return result |
| else: |
| return nest.map_structure(self._local_results, result) |
| |
| def read_var(self, replica_local_var): |
| return array_ops.identity(replica_local_var) |
| |
| def _local_results(self, distributed_value): |
| return (distributed_value,) |
| |
| def value_container(self, value): |
| return value |
| |
| @property |
| def _num_replicas_in_sync(self): |
| return 1 |
| |
| @property |
| def worker_devices(self): |
| raise RuntimeError("worker_devices() method unsupported by default " |
| "tf.distribute.Strategy.") |
| |
| @property |
| def parameter_devices(self): |
| raise RuntimeError("parameter_devices() method unsupported by default " |
| "tf.distribute.Strategy.") |
| |
| def non_slot_devices(self, var_list): |
| return min(var_list, key=lambda x: x.name) |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings.""" |
| # Default strategy doesn't indicate multi-worker training. |
| return False |
| |
| @property |
| def should_checkpoint(self): |
| return True |
| |
| @property |
| def should_save_summary(self): |
| return True |
| |
| def _get_local_replica_id(self, replica_id_in_sync_group): |
| return replica_id_in_sync_group |
| |
| def _get_replica_id_in_sync_group(self, replica_id): |
| return replica_id |
| |
| # TODO(priyag): This should inherit from `InputIterator`, once dependency |
| # issues have been resolved. |
| class DefaultInputIterator(object): |
| """Default implementation of `InputIterator` for default strategy.""" |
| |
| def __init__(self, dataset): |
| self._dataset = dataset |
| if eager_context.executing_eagerly(): |
| self._iterator = dataset_ops.make_one_shot_iterator(dataset) |
| else: |
| self._iterator = dataset_ops.make_initializable_iterator(dataset) |
| |
| def get_next(self): |
| return self._iterator.get_next() |
| |
| def get_next_as_optional(self): |
| return self._iterator.get_next_as_optional() |
| |
| @deprecated(None, "Use the iterator's `initializer` property instead.") |
| def initialize(self): |
| """Initialize underlying iterators. |
| |
| Returns: |
| A list of any initializer ops that should be run. |
| """ |
| if eager_context.executing_eagerly(): |
| self._iterator = self._dataset.make_one_shot_iterator() |
| return [] |
| else: |
| return [self._iterator.initializer] |
| |
| @property |
| def initializer(self): |
| """Returns a list of ops that initialize the iterator.""" |
| return self.initialize() |
| |
| # TODO(priyag): Delete this once all strategies use global batch size. |
| @property |
| def _global_batch_size(self): |
| """Global and per-replica batching are equivalent for this strategy.""" |
| return True |
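The `DefaultInputIterator` above dispatches on execution mode: under eager execution it builds a one-shot iterator (and `initialize()` simply rebuilds it, returning no ops), while in graph mode it builds an initializable iterator and `initialize()` returns the initializer ops a session must run. A rough behavioral sketch, with a plain Python iterable standing in for a `tf.data.Dataset` (the class name and the `"iterator_initializer_op"` string are illustrative placeholders, not TensorFlow APIs):

```python
# Behavioral sketch of DefaultInputIterator's eager/graph dispatch. A plain
# Python iterable stands in for a tf.data.Dataset; "iterator_initializer_op"
# is a placeholder for the graph-mode initializer op.

class SketchInputIterator:

  def __init__(self, dataset, executing_eagerly):
    self._dataset = dataset
    self._eager = executing_eagerly
    self._iterator = iter(dataset)

  def get_next(self):
    return next(self._iterator)

  def initialize(self):
    if self._eager:
      # Eager mode: rebuild a fresh one-shot iterator; no ops to run.
      self._iterator = iter(self._dataset)
      return []
    # Graph mode: return the initializer ops the caller must run.
    return ["iterator_initializer_op"]
```

Note that in the eager branch re-initialization restarts iteration from the beginning, which is why the real implementation can return an empty op list.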
| |
| |
| class _DefaultReplicaContext(ReplicaContext): |
| """ReplicaContext for _DefaultDistributionStrategy.""" |
| |
| @property |
| def replica_id_in_sync_group(self): |
| # Return 0 instead of a constant tensor to avoid creating a new node for |
| # users who don't use distribution strategy. |
| return 0 |
| |
| |
| # ------------------------------------------------------------------------------ |
| # We haven't yet implemented deserialization for DistributedVariables. |
| # So here we catch any attempts to deserialize variables |
| # when using distribution strategies. |
| # pylint: disable=protected-access |
| _original_from_proto = resource_variable_ops._from_proto_fn |
| |
| |
| def _from_proto_fn(v, import_scope=None): |
| if distribution_strategy_context.has_strategy(): |
| raise NotImplementedError( |
| "Deserialization of variables is not yet supported when using a " |
| "tf.distribute.Strategy.") |
| else: |
| return _original_from_proto(v, import_scope=import_scope) |
| |
| resource_variable_ops._from_proto_fn = _from_proto_fn |
| # pylint: enable=protected-access |
| |
| |
| #------------------------------------------------------------------------------- |
| # Shorthand for some methods from distribution_strategy_context. |
| _push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access |
| _get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access |
| _pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access |
| _get_default_replica_mode = ( |
| distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access |
| |
| |
| # ------------------------------------------------------------------------------ |
# Metrics to track which distribution strategy is being used.
| distribution_strategy_gauge = monitoring.StringGauge( |
| "/tensorflow/api/distribution_strategy", |
| "Gauge to track the type of distribution strategy used.", "TFVersion") |
| distribution_strategy_replica_gauge = monitoring.IntGauge( |
| "/tensorflow/api/distribution_strategy/replica", |
| "Gauge to track the number of replica each distribution strategy used.", |
| "CountType") |