| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=line-too-long |
| """Library for running a computation across multiple devices. |
| |
| The intent of this library is that you can write an algorithm in a stylized way |
| and it will be usable with a variety of different `tf.distribute.Strategy` |
| implementations. Each descendant will implement a different strategy for |
| distributing the algorithm across multiple devices/machines. Furthermore, these |
| changes can be hidden inside the specific layers and other library classes that |
| need special treatment to run in a distributed setting, so that most users' |
| model definition code can run unchanged. The `tf.distribute.Strategy` API works |
| the same way with eager and graph execution. |
| |
| *Guides* |
| |
| * [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training) |
| * [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb) |
| |
| *Tutorials* |
| |
| * [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/) |
| |
| The tutorials cover how to use `tf.distribute.Strategy` to do distributed |
| training with native Keras APIs, custom training loops, |
| and Estimator APIs. They also cover how to save/load models when using |
| `tf.distribute.Strategy`. |
| |
| *Glossary* |
| |
| * _Data parallelism_ is where we run multiple copies of the model |
| on different slices of the input data. This is in contrast to |
| _model parallelism_ where we divide up a single copy of a model |
| across multiple devices. |
| Note: we only support data parallelism for now, but |
| hope to add support for model parallelism in the future. |
| * A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that |
| TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple |
| devices on a single machine, or be connected to devices on multiple |
| machines. Devices used to run computations are called _worker devices_. |
| Devices used to store variables are _parameter devices_. For some strategies, |
| such as `tf.distribute.MirroredStrategy`, the worker and parameter devices |
| will be the same (see mirrored variables below). For others they will be |
| different. For example, `tf.distribute.experimental.CentralStorageStrategy` |
| puts the variables on a single device (which may be a worker device or may be |
| the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the |
| variables on separate machines called _parameter servers_ (see below). |
| * A _replica_ is one copy of the model, running on one slice of the |
| input data. Right now each replica is executed on its own |
| worker device, but once we add support for model parallelism |
| a replica may span multiple worker devices. |
| * A _host_ is the CPU device on a machine with worker devices, typically |
| used for running input pipelines. |
| * A _worker_ is defined to be the physical machine(s) containing the physical |
| devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A |
| worker contains one or more |
| replica. Typically one worker will correspond to one machine, but in the case |
| of very large models with model parallelism, one worker may span multiple |
| machines. We typically run one input pipeline per worker, feeding all the |
| replicas on that worker. |
| * _Synchronous_, or more commonly _sync_, training is where the updates from |
| each replica are aggregated together before updating the model variables. This |
| is in contrast to _asynchronous_, or _async_ training, where each replica |
| updates the model variables independently. You may also have replicas |
| partitioned into groups which are in sync within each group but async between |
| groups. |
| * _Parameter servers_: These are machines that hold a single copy of |
| parameters/variables, used by some strategies (right now just |
| `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want |
| to operate on a variable retrieve it at the beginning of a step and send an |
| update to be applied at the end of the step. These can in principle support |
| either sync or async training, but right now we only have support for async |
| training with parameter servers. Compare to |
| `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables |
| on a single device on the same machine (and does sync training), and |
| `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices |
| (see below). |
| |
| * _Replica context_ vs. _Cross-replica context_ vs _Update context_ |
| |
| A _replica context_ applies |
| when you execute the computation function that was called with `strategy.run`. |
| Conceptually, you're in replica context when executing the computation |
| function that is being replicated. |
| |
| An _update context_ is entered in a `tf.distribute.StrategyExtended.update` |
| call. |
| |
| A _cross-replica context_ is entered when you enter a `strategy.scope`. This |
| is useful for calling `tf.distribute.Strategy` methods which operate across |
| the replicas (like `reduce_to()`). By default you start in a _replica context_ |
| (the "default single _replica context_") and then some methods can switch you |
| back and forth. |
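| |
| For illustration, a minimal sketch of the two contexts (assuming a |
| `MirroredStrategy` with two GPUs; the function and variable names are just |
| examples): |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| |
| with strategy.scope(): |
|   # Inside the scope we are in cross-replica context, so |
|   # `get_replica_context()` returns `None`. |
|   assert tf.distribute.get_replica_context() is None |
| |
|   def replica_fn(x): |
|     # Inside `strategy.run` we are in replica context. |
|     replica_ctx = tf.distribute.get_replica_context() |
|     return x * tf.cast(replica_ctx.replica_id_in_sync_group, tf.float32) |
| |
|   strategy.run(replica_fn, args=(tf.constant(1.0),)) |
| ``` |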
| |
| * _Distributed value_: A distributed value is represented by the base class |
| `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful |
| to represent values on multiple devices, and it contains a map from replica id |
| to values. Two representative kinds of `tf.distribute.DistributedValues` are |
| "PerReplica" and "Mirrored" values. |
| |
| "PerReplica" values exist on the worker |
| devices, with a different value for each replica. They are produced by |
| iterating through a distributed dataset returned by |
| `tf.distribute.Strategy.experimental_distribute_dataset` and |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function`. They |
| are also the typical result returned by |
| `tf.distribute.Strategy.run`. |
| |
| "Mirrored" values are like "PerReplica" values, except we know that the values |
| on all replicas are the same. We can safely read a "Mirrored" value in a |
| cross-replica context by using the value on any replica. |
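| |
| For illustration, a small sketch of producing "PerReplica" values by |
| iterating a distributed dataset (assuming a `MirroredStrategy` with two |
| GPUs): |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| dataset = tf.data.Dataset.range(8).batch(4)  # global batch size 4 |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| |
| for x in dist_dataset: |
|   # `x` is a "PerReplica" value: each of the two replicas gets a batch of 2. |
|   per_replica_result = strategy.run(lambda v: v * 2, args=(x,)) |
| ``` |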
| |
| * _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple |
| replicas, like `strategy.run(fn, args=[w])` with an |
| argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will |
| have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc. |
| `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on |
| device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return |
| values from `fn()`, which leads to one common object if the returned values |
| are the same object from every replica, or a `DistributedValues` object |
| otherwise. |
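| |
| For example, a minimal sketch of this unwrapping (assuming a |
| `MirroredStrategy` with two GPUs): |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| # `w` maps replica id 0 -> 0 and replica id 1 -> 1. |
| w = strategy.experimental_distribute_values_from_function( |
|     lambda ctx: tf.constant(ctx.replica_id_in_sync_group)) |
| |
| def fn(w_component): |
|   return w_component + 1 |
| |
| # `fn` receives the component of `w` for its replica; the per-replica return |
| # values are merged back into a distributed value. |
| result = strategy.run(fn, args=(w,)) |
| ``` |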
| |
| * _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating |
| multiple values into one value, like "sum" or "mean". If a strategy is doing |
| sync training, we will perform a reduction on the gradients to a parameter |
| from all replicas before applying the update. _All-reduce_ is an algorithm for |
| performing a reduction on values from multiple devices and making the result |
| available on all of those devices. |
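| |
| For illustration, a minimal sketch of both (assuming a `MirroredStrategy` |
| with two GPUs): |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| |
| @tf.function |
| def step_fn(): |
|   ctx = tf.distribute.get_replica_context() |
|   # All-reduce in replica context: every replica receives the summed value. |
|   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(1.0)) |
| |
| per_replica = strategy.run(step_fn) |
| # Reduction in cross-replica context: aggregate per-replica values into one. |
| total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None) |
| ``` |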
| |
| * _Mirrored variables_: These are variables that are created on multiple |
| devices, where we keep the variables in sync by applying the same |
| updates to every copy. Mirrored variables are created with |
| `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`. |
| Normally they are only used in synchronous training. |
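| |
| For example, a minimal sketch of creating a mirrored variable (assuming a |
| `MirroredStrategy` with two GPUs): |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| with strategy.scope(): |
|   # One component of `v` is created on each replica; writes are applied to |
|   # every component to keep them in sync. |
|   v = tf.Variable( |
|       1.0, |
|       synchronization=tf.VariableSynchronization.ON_WRITE, |
|       aggregation=tf.VariableAggregation.MEAN) |
| ``` |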
| |
| * _SyncOnRead variables_ |
| |
| _SyncOnRead variables_ are created by |
| `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and |
| they are created on multiple devices. In replica context, each |
| component variable on the local replica can perform reads and writes without |
| synchronization with each other. When the |
| _SyncOnRead variable_ is read in cross-replica context, the values from |
| component variables are aggregated and returned. |
| |
| _SyncOnRead variables_ add a lot of complexity to the underlying |
| logic, so we do not encourage users to instantiate and use |
| _SyncOnRead variables_ on their own. We have mainly used _SyncOnRead |
| variables_ for use cases such as batch norm and metrics. For performance |
| reasons, we often don't need to keep these statistics in sync every step and |
| they can be accumulated on each replica independently. The only time we want |
| to sync them is reporting or checkpointing, which typically happens in |
| cross-replica context. _SyncOnRead variables_ are also often used by advanced |
| users who want to control when variable values are aggregated. For example, |
| users sometimes want to maintain gradients independently on each replica for a |
| couple of steps without aggregation. |
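| |
| For illustration, a minimal sketch of creating a _SyncOnRead variable_ |
| (assuming a `MirroredStrategy` with two GPUs): |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| with strategy.scope(): |
|   # Each replica updates its own component without synchronization; reading |
|   # the variable in cross-replica context returns the sum of the components. |
|   total = tf.Variable( |
|       0.0, |
|       synchronization=tf.VariableSynchronization.ON_READ, |
|       aggregation=tf.VariableAggregation.SUM) |
| ``` |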
| |
| * _Distribute-aware layers_ |
| |
| Layers are generally called in a replica context, except when defining a |
| Keras functional model. `tf.distribute.in_cross_replica_context` will let you |
| determine which case you are in. If in a replica context, |
| the `tf.distribute.get_replica_context` function will return the default |
| replica context outside a strategy scope, `None` within a strategy scope, and |
| a `tf.distribute.ReplicaContext` object inside a strategy scope and within a |
| `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an |
| `all_reduce` method for aggregating across all replicas. |
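| |
| For example, a minimal sketch of a distribute-aware layer (the layer itself |
| is hypothetical): |
| |
| ```python |
| class SumAcrossReplicas(tf.keras.layers.Layer): |
| |
|   def call(self, inputs): |
|     if tf.distribute.in_cross_replica_context(): |
|       # E.g. when building a Keras functional model; nothing to aggregate. |
|       return inputs |
|     replica_ctx = tf.distribute.get_replica_context() |
|     # Aggregate the per-replica values across all replicas. |
|     return replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, inputs) |
| ``` |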
| |
| |
| Note that we provide a default version of `tf.distribute.Strategy` that is |
| used when no other strategy is in scope, that provides the same API with |
| reasonable default behavior. |
| """ |
| # pylint: enable=line-too-long |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import copy |
| import enum # pylint: disable=g-bad-import-order |
| import threading |
| import weakref |
| |
| import six |
| |
| from tensorflow.python.autograph.core import ag_ctx as autograph_ctx |
| from tensorflow.python.autograph.impl import api as autograph |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.data.ops import iterator_ops |
| from tensorflow.python.distribute import collective_util |
| from tensorflow.python.distribute import device_util |
| from tensorflow.python.distribute import distribution_strategy_context |
| from tensorflow.python.distribute import numpy_dataset |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.eager import context as eager_context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import monitoring |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import summary_ops_v2 |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops.losses import loss_reduction |
| from tensorflow.python.ops.losses import losses_impl |
| from tensorflow.python.platform import tf_logging |
| from tensorflow.python.training.tracking import base as trackable |
| from tensorflow.python.util import deprecation |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_contextlib |
| from tensorflow.python.util.deprecation import deprecated |
| from tensorflow.python.util.tf_export import tf_export |
| from tensorflow.tools.docs import doc_controls |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Context tracking whether in a strategy.update() or .update_non_slot() call. |
| |
| |
| _update_replica_id = threading.local() |
| |
| |
| def get_update_replica_id(): |
| """Get the current device if in a `tf.distribute.Strategy.update()` call.""" |
| try: |
| return _update_replica_id.current |
| except AttributeError: |
| return None |
| |
| |
| class UpdateContext(object): |
| """Context manager when you are in `update()` or `update_non_slot()`.""" |
| |
| __slots__ = ["_replica_id", "_old_replica_id"] |
| |
| def __init__(self, replica_id): |
| self._replica_id = replica_id |
| self._old_replica_id = None |
| |
| def __enter__(self): |
| self._old_replica_id = get_update_replica_id() |
| _update_replica_id.current = self._replica_id |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| del exception_type, exception_value, traceback |
| _update_replica_id.current = self._old_replica_id |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Public utility functions. |
| |
| |
| @tf_export(v1=["distribute.get_loss_reduction"]) |
| def get_loss_reduction(): |
| """`tf.distribute.ReduceOp` corresponding to the last loss reduction. |
| |
| This is used to decide whether loss should be scaled in the optimizer (used |
| only for the Estimator + v1 optimizer use case). |
| |
| Returns: |
| `tf.distribute.ReduceOp` corresponding to the last loss reduction for |
| estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise. |
| """ |
| if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access |
| # If we are not in Estimator context then return 'SUM'. We do not need to |
| # scale loss in the optimizer. |
| return reduce_util.ReduceOp.SUM |
| last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access |
| if (last_reduction == losses_impl.Reduction.SUM or |
| last_reduction == loss_reduction.ReductionV2.SUM): |
| return reduce_util.ReduceOp.SUM |
| return reduce_util.ReduceOp.MEAN |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Internal API for validating the current thread mode |
| |
| |
| def _require_cross_replica_or_default_context_extended(extended): |
| """Verify in cross-replica context.""" |
| context = _get_per_thread_mode() |
| cross_replica = context.cross_replica_context |
| if cross_replica is not None and cross_replica.extended is extended: |
| return |
| if context is _get_default_replica_mode(): |
| return |
| strategy = extended._container_strategy() # pylint: disable=protected-access |
| # We have an error to report, figure out the right message. |
| if context.strategy is not strategy: |
| _wrong_strategy_scope(strategy, context) |
| assert cross_replica is None |
| raise RuntimeError("Method requires being in cross-replica context, use " |
| "get_replica_context().merge_call()") |
| |
| |
| def _wrong_strategy_scope(strategy, context): |
| # Figure out the right error message. |
| if not distribution_strategy_context.has_strategy(): |
| raise RuntimeError( |
| 'Need to be inside "with strategy.scope()" for %s' % |
| (strategy,)) |
| else: |
| raise RuntimeError( |
| "Mixing different tf.distribute.Strategy objects: %s is not %s" % |
| (context.strategy, strategy)) |
| |
| |
| def require_replica_context(replica_ctx): |
| """Verify in `replica_ctx` replica context.""" |
| context = _get_per_thread_mode() |
| if context.replica_context is replica_ctx: return |
| # We have an error to report, figure out the right message. |
| if context.replica_context is None: |
| raise RuntimeError("Need to be inside `call_for_each_replica()`") |
| if context.strategy is replica_ctx.strategy: |
| # Two different ReplicaContexts with the same tf.distribute.Strategy. |
| raise RuntimeError("Mismatching ReplicaContext.") |
| raise RuntimeError( |
| "Mismatching tf.distribute.Strategy objects: %s is not %s." % |
| (context.strategy, replica_ctx.strategy)) |
| |
| |
| def _require_strategy_scope_strategy(strategy): |
| """Verify in a `strategy.scope()` in this thread.""" |
| context = _get_per_thread_mode() |
| if context.strategy is strategy: return |
| _wrong_strategy_scope(strategy, context) |
| |
| |
| def _require_strategy_scope_extended(extended): |
| """Verify in a `distribution_strategy.scope()` in this thread.""" |
| context = _get_per_thread_mode() |
| if context.strategy.extended is extended: return |
| # Report error. |
| strategy = extended._container_strategy() # pylint: disable=protected-access |
| _wrong_strategy_scope(strategy, context) |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Internal context managers used to implement the DistributionStrategy |
| # base class |
| |
| |
| class _CurrentDistributionContext(object): |
| """Context manager setting the current `tf.distribute.Strategy`. |
| |
| Also: overrides the variable creator and optionally the current device. |
| """ |
| |
| def __init__(self, |
| strategy, |
| var_creator_scope, |
| var_scope=None, |
| default_device=None): |
| self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access |
| strategy) |
| self._var_creator_scope = var_creator_scope |
| self._var_scope = var_scope |
| if default_device: |
| self._device_scope = ops.device(default_device) |
| else: |
| self._device_scope = None |
| self._same_scope_again_count = 0 |
| |
| def __enter__(self): |
| # Allow this scope to be entered if this strategy is already in scope. |
| if distribution_strategy_context.has_strategy(): |
| _require_cross_replica_or_default_context_extended( |
| self._context.strategy.extended) |
| self._same_scope_again_count += 1 |
| else: |
| _push_per_thread_mode(self._context) |
| if self._var_scope: |
| self._var_scope.__enter__() |
| self._var_creator_scope.__enter__() |
| if self._device_scope: |
| self._device_scope.__enter__() |
| return self._context.strategy |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| if self._same_scope_again_count > 0: |
| self._same_scope_again_count -= 1 |
| return |
| if self._device_scope: |
| try: |
| self._device_scope.__exit__(exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Device scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| try: |
| self._var_creator_scope.__exit__( |
| exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable creator scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| if self._var_scope: |
| try: |
| self._var_scope.__exit__(exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| _pop_per_thread_mode() |
| |
| |
| # TODO(yuefengz): add more replication modes. |
| @tf_export("distribute.InputReplicationMode") |
| class InputReplicationMode(enum.Enum): |
| """Replication mode for input function. |
| |
| * `PER_WORKER`: The input function will be called on each worker |
| independently, creating as many input pipelines as the number of workers. |
| Replicas will dequeue from the local Dataset on their worker. |
| `tf.distribute.Strategy` doesn't manage any state sharing between such |
| separate input pipelines. |
| """ |
| PER_WORKER = "PER_WORKER" |
| |
| |
| @tf_export("distribute.InputContext") |
| class InputContext(object): |
| """A class wrapping information needed by an input function. |
| |
| This is a context class that is passed to the user's input function and |
| contains information about the compute replicas and input pipelines. The |
| number of compute replicas (in sync training) helps compute the local batch |
| size from the desired global batch size for each replica. The input pipeline |
| information can be used to return a different subset of the input in each |
| replica (e.g. shard the input pipeline, use a different input |
| source, etc.). |
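| |
| For example, a minimal sketch of an input function that uses the context to |
| shard and batch its data (the dataset contents and batch size are just |
| placeholders): |
| |
| ```python |
| def dataset_fn(input_context): |
|   global_batch_size = 64 |
|   batch_size = input_context.get_per_replica_batch_size(global_batch_size) |
|   dataset = tf.data.Dataset.range(1024) |
|   dataset = dataset.shard(input_context.num_input_pipelines, |
|                           input_context.input_pipeline_id) |
|   return dataset.batch(batch_size) |
| ``` |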
| """ |
| |
| __slots__ = [ |
| "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync" |
| ] |
| |
| def __init__(self, |
| num_input_pipelines=1, |
| input_pipeline_id=0, |
| num_replicas_in_sync=1): |
| """Initializes an InputContext object. |
| |
| Args: |
| num_input_pipelines: the number of input pipelines in a cluster. |
| input_pipeline_id: the current input pipeline id, should be an int in |
| [0,`num_input_pipelines`). |
| num_replicas_in_sync: the number of replicas that are in sync. |
| """ |
| self._num_input_pipelines = num_input_pipelines |
| self._input_pipeline_id = input_pipeline_id |
| self._num_replicas_in_sync = num_replicas_in_sync |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns the number of compute replicas in sync.""" |
| return self._num_replicas_in_sync |
| |
| @property |
| def input_pipeline_id(self): |
| """Returns the input pipeline ID.""" |
| return self._input_pipeline_id |
| |
| @property |
| def num_input_pipelines(self): |
| """Returns the number of input pipelines.""" |
| return self._num_input_pipelines |
| |
| def get_per_replica_batch_size(self, global_batch_size): |
| """Returns the per-replica batch size. |
| |
| Args: |
| global_batch_size: the global batch size which should be divisible by |
| `num_replicas_in_sync`. |
| |
| Returns: |
| the per-replica batch size. |
| |
| Raises: |
| ValueError: if `global_batch_size` not divisible by |
| `num_replicas_in_sync`. |
| """ |
| if global_batch_size % self._num_replicas_in_sync != 0: |
| raise ValueError("The `global_batch_size` %r is not divisible by " |
| "`num_replicas_in_sync` %r " % |
| (global_batch_size, self._num_replicas_in_sync)) |
| return global_batch_size // self._num_replicas_in_sync |
| |
| def __str__(self): |
| return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( |
| self.input_pipeline_id, self.num_input_pipelines) |
| |
| |
| @tf_export("distribute.experimental.ValueContext", v1=[]) |
| class ValueContext(object): |
| """A class wrapping information needed by a distribute function. |
| |
| This is a context class that is passed to the `value_fn` in |
| `strategy.experimental_distribute_values_from_function` and contains |
| information about the compute replicas. The `num_replicas_in_sync` and |
| `replica_id` can be used to customize the value on each replica. |
| |
| Example usage: |
| |
| 1. Directly constructed. |
| |
| >>> def value_fn(context): |
| ... return context.replica_id_in_sync_group/context.num_replicas_in_sync |
| >>> context = tf.distribute.experimental.ValueContext( |
| ... replica_id_in_sync_group=2, num_replicas_in_sync=4) |
| >>> per_replica_value = value_fn(context) |
| >>> per_replica_value |
| 0.5 |
| |
| 2. Passed in by `experimental_distribute_values_from_function`. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def value_fn(value_context): |
| ... return value_context.num_replicas_in_sync |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (2, 2) |
| |
| """ |
| |
| __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"] |
| |
| def __init__(self, |
| replica_id_in_sync_group=0, |
| num_replicas_in_sync=1): |
| """Initializes an ValueContext object. |
| |
| Args: |
| replica_id_in_sync_group: the current replica_id, should be an int in |
| [0,`num_replicas_in_sync`). |
| num_replicas_in_sync: the number of replicas that are in sync. |
| """ |
| self._replica_id_in_sync_group = replica_id_in_sync_group |
| self._num_replicas_in_sync = num_replicas_in_sync |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns the number of compute replicas in sync.""" |
| return self._num_replicas_in_sync |
| |
| @property |
| def replica_id_in_sync_group(self): |
| """Returns the replica ID.""" |
| return self._replica_id_in_sync_group |
| |
| def __str__(self): |
| return (("tf.distribute.ValueContext(replica id {}, " |
| " total replicas in sync: ""{})") |
| .format(self.replica_id_in_sync_group, self.num_replicas_in_sync)) |
| |
| |
| @tf_export("distribute.RunOptions") |
| class RunOptions( |
| collections.namedtuple("RunOptions", [ |
| "experimental_enable_dynamic_batch_size", |
| "experimental_bucketizing_dynamic_shape", |
| ])): |
| """Run options for `strategy.run`. |
| |
| This can be used to hold some strategy-specific configs. |
| |
| Attributes: |
| experimental_enable_dynamic_batch_size: Boolean. Only applies to |
| TPUStrategy. Defaults to True. If True, TPUStrategy will enable dynamic |
| padder to support dynamic batch size for the inputs. Otherwise only static |
| shape inputs are allowed. |
| experimental_bucketizing_dynamic_shape: Boolean. Only applies to |
| TPUStrategy. Defaults to False. If True, TPUStrategy will automatically |
| bucketize inputs passed into `run` if the input shape is |
| dynamic. This is a performance optimization to reduce XLA recompilation, |
| which should not have any impact on correctness. |
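| |
| For example, a minimal sketch of passing run options to `TPUStrategy.run` |
| (TPU setup as in other examples; requires an actual TPU to run): |
| |
| ```python |
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') |
| tf.config.experimental_connect_to_cluster(resolver) |
| tf.tpu.experimental.initialize_tpu_system(resolver) |
| strategy = tf.distribute.TPUStrategy(resolver) |
| |
| @tf.function |
| def replica_fn(x): |
|   return x * 2.0 |
| |
| options = tf.distribute.RunOptions( |
|     experimental_bucketizing_dynamic_shape=True) |
| result = strategy.run(replica_fn, args=(tf.constant(1.0),), options=options) |
| ``` |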
| """ |
| |
| def __new__(cls, |
| experimental_enable_dynamic_batch_size=True, |
| experimental_bucketizing_dynamic_shape=False): |
| return super(RunOptions, |
| cls).__new__(cls, experimental_enable_dynamic_batch_size, |
| experimental_bucketizing_dynamic_shape) |
| |
| |
| @tf_export("distribute.InputOptions", v1=[]) |
| class InputOptions( |
| collections.namedtuple("InputOptions", [ |
| "experimental_prefetch_to_device", |
| ])): |
| """Run options for `experimental_distribute_dataset(s_from_function)`. |
| |
| This can be used to hold some strategy-specific configs. |
| |
| ```python |
| # Setup TPUStrategy |
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') |
| tf.config.experimental_connect_to_cluster(resolver) |
| tf.tpu.experimental.initialize_tpu_system(resolver) |
| strategy = tf.distribute.TPUStrategy(resolver) |
| |
| dataset = tf.data.Dataset.range(16) |
| distributed_dataset_on_host = ( |
| strategy.experimental_distribute_dataset( |
| dataset, |
| tf.distribute.InputOptions( |
| experimental_prefetch_to_device=False))) |
| ``` |
| |
| Attributes: |
| experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset |
| elements will be prefetched to accelerator device memory. When False, |
| dataset elements are prefetched to host device memory. Must be False when |
| using TPUEmbedding API. |
| """ |
| |
| def __new__(cls, experimental_prefetch_to_device=True): |
| return super(InputOptions, cls).__new__(cls, |
| experimental_prefetch_to_device) |
| |
| # ------------------------------------------------------------------------------ |
| # Base classes for all distribution strategies. |
| |
| |
| # Base class for v1 Strategy and v2 Strategy classes. For API's specific to |
| # v1/v2 Strategy, add to implementing classes of StrategyBase. |
| # pylint: disable=line-too-long |
| class StrategyBase(object): |
| """A state & compute distribution policy on a list of devices. |
| |
| See [the guide](https://www.tensorflow.org/guide/distributed_training) |
| for overview and examples. See `tf.distribute.StrategyExtended` and |
| [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute) |
| for a glossary of concepts mentioned on this page such as "per-replica", |
| _replica_, and _reduce_. |
| |
| In short: |
| |
| * To use it with Keras `compile`/`fit`, |
| [please |
| read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). |
| * You may pass a descendant of `tf.distribute.Strategy` to |
| `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` |
| should distribute its computation. See |
| [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support). |
| * Otherwise, use `tf.distribute.Strategy.scope` to specify that a |
| strategy should be used when building and executing your model. |
| (This puts you in the "cross-replica context" for this strategy, which |
| means the strategy is put in control of things like variable placement.) |
| * If you are writing a custom training loop, you will need to call a few more |
| methods, |
| [see the |
| guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): |
| |
| * Start by either creating a `tf.data.Dataset` normally or using |
| `tf.distribute.experimental_make_numpy_dataset` to make a dataset out of |
| a `numpy` array. |
| * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert |
| a `tf.data.Dataset` to something that produces "per-replica" values. |
| If you want to manually specify how the dataset should be partitioned |
| across replicas, use |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function` |
| instead. |
| * Use `tf.distribute.Strategy.run` to run a function |
| once per replica, taking values that may be "per-replica" (e.g. |
| from a `tf.distribute.DistributedDataset` object) and returning |
| "per-replica" values. |
| This function is executed in "replica context", which means each |
| operation is performed separately on each replica. |
| * Finally use a method (such as `tf.distribute.Strategy.reduce`) to |
| convert the resulting "per-replica" values into ordinary `Tensor`s. |
| |
| A custom training loop can be as simple as: |
| |
| ``` |
| with my_strategy.scope(): |
| @tf.function |
| def distribute_train_epoch(dataset): |
| def replica_fn(input): |
| # process input and return result |
| return result |
| |
| total_result = 0 |
| for x in dataset: |
| per_replica_result = my_strategy.run(replica_fn, args=(x,)) |
| total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, |
| per_replica_result, axis=None) |
| return total_result |
| |
| dist_dataset = my_strategy.experimental_distribute_dataset(dataset) |
| for _ in range(EPOCHS): |
| train_result = distribute_train_epoch(dist_dataset) |
| ``` |
| |
| This takes an ordinary `dataset` and `replica_fn` and runs it |
| distributed using a particular `tf.distribute.Strategy` named |
| `my_strategy` above. Any variables created in `replica_fn` are created |
| using `my_strategy`'s policy, and library functions called by |
| `replica_fn` can use the `get_replica_context()` API to implement |
| distributed-specific behavior. |
| |
| You can use the `reduce` API to aggregate results across replicas and use |
| this as a return value from one iteration over a |
| `tf.distribute.DistributedDataset`. Or |
| you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to |
| accumulate metrics across steps in a given epoch. |
| |
| See the |
| [custom training loop |
| tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) |
| for a more detailed example. |
| |
| Note: `tf.distribute.Strategy` currently does not support TensorFlow's |
| partitioned variables (where a single variable is split across multiple |
| devices). |
| """ |
| # pylint: enable=line-too-long |
| |
| # TODO(josh11b): Partitioned computations, state; sharding |
| # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling |
| |
| def __init__(self, extended): |
| self._extended = extended |
| |
| # Flag that is used to indicate whether distribution strategy is used with |
| # Estimator. This is required for backward compatibility of loss scaling |
| # when using v1 optimizer with estimator. |
| self._scale_loss_for_estimator = False |
| |
| if not hasattr(extended, "_retrace_functions_for_each_device"): |
| # pylint: disable=protected-access |
| # `extended._retrace_functions_for_each_device` dictates |
| # whether the same function will be retraced when it is called on |
| # different devices. |
| try: |
| extended._retrace_functions_for_each_device = ( |
| len(extended.worker_devices) > 1) |
| distribution_strategy_replica_gauge.get_cell("num_replicas").set( |
| self.num_replicas_in_sync) |
| except: # pylint: disable=bare-except |
| # Default for the case where extended.worker_devices can't return |
| # a sensible value. |
| extended._retrace_functions_for_each_device = True |
| |
| # Below are the dicts of axis(int) -> `tf.function`. |
| self._mean_reduce_helper_fns = {} |
| self._reduce_sum_fns = {} |
| |
| @property |
| def extended(self): |
| """`tf.distribute.StrategyExtended` with additional methods.""" |
| return self._extended |
| |
| @tf_contextlib.contextmanager |
| def _scale_loss_for_estimator_enabled(self): |
| """Scope which sets a flag used for scaling losses in optimizer. |
| |
| Yields: |
| `_scale_loss_for_estimator_enabled` is a context manager with a |
| side effect, but doesn't return a value. |
| """ |
| self._scale_loss_for_estimator = True |
| try: |
| yield |
| finally: |
| self._scale_loss_for_estimator = False |
| |
| # pylint: disable=line-too-long |
| def scope(self): |
| """Context manager to make the strategy current and distribute variables. |
| |
| This method returns a context manager, and is used as follows: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> # Variable created inside scope: |
| >>> with strategy.scope(): |
| ... mirrored_variable = tf.Variable(1.) |
| >>> mirrored_variable |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> |
| } |
| >>> # Variable created outside scope: |
| >>> regular_variable = tf.Variable(1.) |
| >>> regular_variable |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> |
| |
| _What happens when Strategy.scope is entered?_ |
| |
| * `strategy` is installed in the global context as the "current" strategy. |
| Inside this scope, `tf.distribute.get_strategy()` will now return this |
| strategy. Outside this scope, it returns the default no-op strategy. |
| * Entering the scope also enters the "cross-replica context". See |
| `tf.distribute.StrategyExtended` for an explanation on cross-replica and |
| replica contexts. |
| * Variable creation inside `scope` is intercepted by the strategy. Each |
| strategy defines how it wants to affect the variable creation. Sync |
| strategies like `MirroredStrategy`, `TPUStrategy` and |
| `MultiWorkerMirroredStrategy` create variables replicated on each replica, |
| whereas `ParameterServerStrategy` creates variables on the parameter |
| servers. This is done using a custom `tf.variable_creator_scope`. |
| * In some strategies, a default device scope may also be entered: in |
| `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is |
| entered on each worker. |
| |
| Note: Entering a scope does not automatically distribute a computation, except |
| in the case of high-level training frameworks like Keras `model.fit`. If |
| you're not using `model.fit`, you |
| need to use the `strategy.run` API to explicitly distribute that computation. |
| See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training). |
| |
| |
| _What should be in scope and what should be outside?_ |
| |
| There are a number of requirements on what needs to happen inside the scope. |
| However, in places where we have information about which strategy is in use, |
| we often enter the scope for the user, so they don't have to do it |
| explicitly (i.e. calling those either inside or outside the scope is OK). |
| |
| * Anything that creates variables that should be distributed variables |
| must be in `strategy.scope`. This can be either by directly putting it in |
| scope, or relying on another API like `strategy.run` or `model.fit` to |
| enter it for you. Any variable that is created outside scope will not be |
| distributed and may have performance implications. Common things that |
| create variables in TF: models, optimizers, metrics. These should always |
| be created inside the scope. Another source of variable creation can be |
| a checkpoint restore - when variables are created lazily. Note that any |
| variable created inside a strategy captures the strategy information. So |
| reading and writing to these variables outside the `strategy.scope` can |
| also work seamlessly, without the user having to enter the scope. |
| * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which |
| require to be in a strategy's scope, enter the scope for you |
| automatically, which means when using those APIs you don't need to |
| enter the scope yourself. |
| * When a `tf.keras.Model` is created inside a `strategy.scope`, we capture |
| this information. When high-level training framework methods such as |
| `model.compile`, `model.fit` etc. are then called |
| on this model, we automatically enter the scope, as well as use this |
| strategy to distribute the training etc. See |
| detailed example in [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras). |
| Note that simply calling the `model(..)` is not impacted - only high |
| level training framework APIs are. `model.compile`, `model.fit`, |
| `model.evaluate`, `model.predict` and `model.save` can all be called |
| inside or outside the scope. |
| * The following can be either inside or outside the scope: |
| ** Creating the input datasets |
| ** Defining `tf.function`s that represent your training step |
| ** Saving APIs such as `tf.saved_model.save`. Loading creates variables, |
| so that should go inside the scope if you want to train the model in a |
| distributed way. |
| ** Checkpoint saving. As mentioned above - `checkpoint.restore` may |
| sometimes need to be inside scope if it creates variables. |
| |
| Returns: |
| A context manager. |
| """ |
| return self._extended._scope(self) # pylint: disable=protected-access |
| # pylint: enable=line-too-long |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` |
| def colocate_vars_with(self, colocate_with_variable): |
| """DEPRECATED: use extended.colocate_vars_with() instead.""" |
| return self._extended.colocate_vars_with(colocate_with_variable) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def make_dataset_iterator(self, dataset): |
| """DEPRECATED TF 1.x ONLY.""" |
| return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def make_input_fn_iterator(self, |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| """DEPRECATED TF 1.x ONLY.""" |
| if replication_mode != InputReplicationMode.PER_WORKER: |
| raise ValueError( |
| "Input replication mode not supported: %r" % replication_mode) |
| with self.scope(): |
| return self.extended._make_input_fn_iterator( # pylint: disable=protected-access |
| input_fn, replication_mode=replication_mode) |
| |
| @deprecation.deprecated( |
| "2020-09-30", "Please use tf.data.Dataset.from_tensor_slices instead") |
| def experimental_make_numpy_dataset(self, numpy_input): |
| """Makes a `tf.data.Dataset` from a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
| |
| Note that you will likely need to use `experimental_distribute_dataset` |
| with the returned dataset to further distribute it with the strategy. |
| |
| Example: |
| |
| >>> strategy = tf.distribute.MirroredStrategy() |
| >>> numpy_input = np.ones([10], dtype=np.float32) |
| >>> dataset = strategy.experimental_make_numpy_dataset(numpy_input) |
| >>> dataset |
| <TensorSliceDataset shapes: (), types: tf.float32> |
| >>> dataset = dataset.batch(2) |
| >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| |
| Args: |
| numpy_input: a nest of NumPy input arrays that will be converted into a |
| dataset. Note that the NumPy arrays are stacked, as that is normal |
| `tf.data.Dataset` behavior. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| return self.extended.experimental_make_numpy_dataset( |
| numpy_input, session=None) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def experimental_run(self, fn, input_iterator=None): |
| """DEPRECATED TF 1.x ONLY.""" |
| with self.scope(): |
| args = (input_iterator.get_next(),) if input_iterator is not None else () |
| return self.run(fn, args=args) |
| |
| def experimental_distribute_dataset(self, dataset, options=None): |
| # pylint: disable=line-too-long |
| """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`. |
| |
| The returned `tf.distribute.DistributedDataset` can be iterated over |
| similar to regular datasets. |
| NOTE: The user cannot add any more transformations to a |
| `tf.distribute.DistributedDataset`. You can only create an iterator or |
| examine the `tf.TypeSpec` of the data generated by it. See API docs of |
| `tf.distribute.DistributedDataset` to learn more. |
| |
| The following is an example: |
| |
| >>> global_batch_size = 2 |
| >>> # Passing the devices is optional. |
| ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) |
| >>> # Create a dataset |
| ... dataset = tf.data.Dataset.range(4).batch(global_batch_size) |
| >>> # Distribute that dataset |
| ... dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| >>> @tf.function |
| ... def replica_fn(input): |
| ... return input*2 |
| >>> result = [] |
| >>> # Iterate over the `tf.distribute.DistributedDataset` |
| ... for x in dist_dataset: |
| ... # process dataset elements |
| ... result.append(strategy.run(replica_fn, args=(x,))) |
| >>> print(result) |
| [PerReplica:{ |
| 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>, |
| 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])> |
| }, PerReplica:{ |
| 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>, |
| 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])> |
| }] |
| |
| |
| Three key actions happening under the hood of this method are batching, |
| sharding, and prefetching. |
| |
| In the code snippet above, `dataset` is batched by `global_batch_size`, and |
| calling `experimental_distribute_dataset` on it rebatches `dataset` to a |
| new batch size that is equal to the global batch size divided by the number |
| of replicas in sync. We iterate through it using a Pythonic for loop. |
| `x` is a `tf.distribute.DistributedValues` containing data for all replicas, |
| and each replica gets data of the new batch size. |
| `tf.distribute.Strategy.run` will take care of feeding the right per-replica |
| data in `x` to the right `replica_fn` executed on each replica. |
| |
| Sharding covers autosharding across multiple workers and within every |
| worker. First, in multi-worker distributed training (i.e. when you use |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy` |
| or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of |
| workers means that each worker is assigned a subset of the entire dataset |
| (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to |
| ensure that at each step, a global batch size of non-overlapping dataset |
| elements will be processed by each worker. Autosharding has a couple of |
| different options that can be specified using |
| `tf.data.experimental.DistributeOptions`. Then, sharding within each worker |
| means the method will split the data among all the worker devices (if more |
| than one is present). This will happen regardless of multi-worker |
| autosharding. |
| |
| Note: for autosharding across multiple workers, the default mode is |
| `tf.data.experimental.AutoShardPolicy.AUTO`. This mode |
| will attempt to shard the input dataset by files if the dataset is |
| being created out of reader datasets (e.g. `tf.data.TFRecordDataset`, |
| `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data, |
| where each of the workers will read the entire dataset and only process the |
| shard assigned to it. However, if you have less than one input file per |
| worker, we suggest that you disable dataset autosharding across workers by |
| setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be |
| `tf.data.experimental.AutoShardPolicy.OFF`. |
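| |
| For example, a sketch of disabling autosharding on the dataset before |
| distributing it (continuing the example above): |
| |
| ```python |
| options = tf.data.Options() |
| options.experimental_distribute.auto_shard_policy = ( |
|     tf.data.experimental.AutoShardPolicy.OFF) |
| dataset = dataset.with_options(options) |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| ``` |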
| |
| By default, this method adds a prefetch transformation at the end of the |
| user provided `tf.data.Dataset` instance. The `buffer_size` argument of the |
| prefetch transformation is equal to the number of replicas in |
| sync. |
| |
| If the above batch splitting and dataset sharding logic is undesirable, |
| please use |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function` |
| instead, which does not do any automatic batching or sharding for you. |
| |
| Note: If you are using TPUStrategy, the order in which the data is processed |
| by the workers when using |
| `tf.distribute.Strategy.experimental_distribute_dataset` or |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function` is |
| not guaranteed. Data ordering typically matters if you are using |
| `tf.distribute` to scale prediction. You can however insert an index for |
| each element in the batch and order outputs accordingly. Refer to [this |
| snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) |
| for an example of how to order outputs. |
| |
| Note: Stateful dataset transformations are currently not supported with |
| `tf.distribute.experimental_distribute_dataset` or |
| `tf.distribute.experimental_distribute_datasets_from_function`. Any stateful |
| ops that the dataset may have are currently ignored. For example, if your |
| dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, |
| then you have a dataset graph that depends on state (i.e. the random seed) on |
| the local machine where the python process is being executed. |
| |
| For a tutorial on more usage and properties of this method, refer to the |
| [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset). |
| If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). |
| |
| Args: |
| dataset: `tf.data.Dataset` that will be sharded across all replicas using |
| the rules stated above. |
| options: `tf.distribute.InputOptions` used to control options on how this |
| dataset is distributed. |
| |
| Returns: |
| A `tf.distribute.DistributedDataset`. |
| """ |
| # pylint: enable=line-too-long |
| return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access |
| |
| def experimental_distribute_datasets_from_function(self, dataset_fn, |
| options=None): |
| # pylint: disable=line-too-long |
| """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. |
| |
| The argument `dataset_fn` that users pass in is an input function that has a |
| `tf.distribute.InputContext` argument and returns a `tf.data.Dataset` |
| instance. It is expected that the returned dataset from `dataset_fn` is |
| already batched by per-replica batch size (i.e. global batch size divided by |
| the number of replicas in sync) and sharded. |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function` does |
| not batch or shard the `tf.data.Dataset` instance |
| returned from the input function. `dataset_fn` will be called on the CPU |
| device of each of the workers and each generates a dataset where every |
| replica on that worker will dequeue one batch of inputs (i.e. if a worker |
| has two replicas, two batches will be dequeued from the `Dataset` every |
| step). |
| |
| This method can be used for several purposes. First, it allows you to |
| specify your own batching and sharding logic. (In contrast, |
| `tf.distribute.experimental_distribute_dataset` does batching and sharding |
| for you.) For example, where |
| `experimental_distribute_dataset` is unable to shard the input files, this |
| method might be used to manually shard the dataset (avoiding the slow |
| fallback behavior in `experimental_distribute_dataset`). In cases where the |
| dataset is infinite, this sharding can be done by creating dataset replicas |
| that differ only in their random seed. |
| |
| The `dataset_fn` should take a `tf.distribute.InputContext` instance where |
| information about batching and input replication can be accessed. |
| |
| You can use `element_spec` property of the |
| `tf.distribute.DistributedDataset` returned by this API to query the |
| `tf.TypeSpec` of the elements returned by the iterator. This can be used to |
| set the `input_signature` property of a `tf.function`. Follow |
| `tf.distribute.DistributedDataset.element_spec` to see an example. |
| |
| IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a |
| per-replica batch size, unlike `experimental_distribute_dataset`, which uses |
| the global batch size. This may be computed using |
| `input_context.get_per_replica_batch_size`. |
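| |
| For example, a minimal sketch of such a `dataset_fn` (the dataset contents |
| and global batch size are placeholders, and `strategy` is assumed to be an |
| existing `tf.distribute.Strategy`): |
| |
| ```python |
| def dataset_fn(input_context): |
|   batch_size = input_context.get_per_replica_batch_size(64)  # global batch 64 |
|   d = tf.data.Dataset.range(1000) |
|   d = d.shard(input_context.num_input_pipelines, |
|               input_context.input_pipeline_id) |
|   return d.batch(batch_size) |
| |
| dist_dataset = ( |
|     strategy.experimental_distribute_datasets_from_function(dataset_fn)) |
| ``` |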
| |
| Note: If you are using TPUStrategy, the order in which the data is processed |
| by the workers when using |
| `tf.distribute.Strategy.experimental_distribute_dataset` or |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function` is |
| not guaranteed. Data ordering typically matters if you are using |
| `tf.distribute` to scale prediction. You can however insert an index for |
| each element in the batch and order outputs accordingly. Refer to [this |
| snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) |
| for an example of how to order outputs. |
| |
| Note: Stateful dataset transformations are currently not supported with |
| `tf.distribute.experimental_distribute_dataset` or |
| `tf.distribute.experimental_distribute_datasets_from_function`. Any stateful |
| ops that the dataset may have are currently ignored. For example, if your |
| dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, |
| then you have a dataset graph that depends on state (i.e. the random seed) on |
| the local machine where the python process is being executed. |
| |
| For a tutorial on more usage and properties of this method, refer to the |
| [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function). |
| If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). |
| |
| Args: |
| dataset_fn: A function taking a `tf.distribute.InputContext` instance and |
| returning a `tf.data.Dataset`. |
| options: `tf.distribute.InputOptions` used to control options on how this |
| dataset is distributed. |
| |
| Returns: |
| A `tf.distribute.DistributedDataset`. |
| """ |
| # pylint: enable=line-too-long |
| return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access |
| dataset_fn, options) |
| |
| def run(self, fn, args=(), kwargs=None, options=None): |
| """Run `fn` on each replica, with the given arguments. |
| |
| Executes ops specified by `fn` on each replica. If `args` or `kwargs` have |
| `tf.distribute.DistributedValues`, such as those produced by a |
| `tf.distribute.DistributedDataset` from |
| `tf.distribute.Strategy.experimental_distribute_dataset` or |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function`, |
| when `fn` is executed on a particular replica, it will be executed with the |
| component of `tf.distribute.DistributedValues` that correspond to that |
| replica. |
| |
| `fn` may call `tf.distribute.get_replica_context()` to access members such |
| as `all_reduce`. |
| |
| All arguments in `args` or `kwargs` should either be a nest of tensors or |
| `tf.distribute.DistributedValues` containing tensors or composite tensors. |
| |
| IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and |
| whether eager execution is enabled, `fn` may be called one or more times. If |
| `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is |
| called inside a `tf.function`, eager execution is disabled and `fn` is |
| called once (or once per replica, if you are using MirroredStrategy) to |
| generate a TensorFlow graph, which will then be reused for execution with |
| new inputs. Otherwise, if eager execution is enabled, `fn` will be called |
| every step just like regular Python code. |
| |
| Example usage: |
| |
| 1. Constant tensor input. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> tensor_input = tf.constant(3.0) |
| >>> @tf.function |
| ... def replica_fn(input): |
| ... return input*2.0 |
| >>> result = strategy.run(replica_fn, args=(tensor_input,)) |
| >>> result |
| PerReplica:{ |
| 0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>, |
| 1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0> |
| } |
| |
| 2. DistributedValues input. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> @tf.function |
| ... def run(): |
| ... def value_fn(value_context): |
| ... return value_context.num_replicas_in_sync |
| ... distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| ... def replica_fn2(input): |
| ... return input*2 |
| ... return strategy.run(replica_fn2, args=(distributed_values,)) |
| >>> result = run() |
| >>> result |
| <tf.Tensor: shape=(), dtype=int32, numpy=4> |
| |
| Args: |
| fn: The function to run. The output must be a `tf.nest` of `Tensor`s. |
| args: (Optional) Positional arguments to `fn`. |
| kwargs: (Optional) Keyword arguments to `fn`. |
| options: (Optional) An instance of `tf.distribute.RunOptions` specifying |
| the options to run `fn`. |
| |
| Returns: |
| Merged return value of `fn` across replicas. The structure of the return |
| value is the same as the return value from `fn`. Each element in the |
| structure can either be `tf.distribute.DistributedValues` or `Tensor` |
| objects (for example, if running on a single replica). |
| """ |
| del options |
| |
| if not isinstance(args, (list, tuple)): |
| raise ValueError( |
| "positional args must be a list or tuple, got {}".format(type(args))) |
| |
| with self.scope(): |
| # tf.distribute supports Eager functions, so AutoGraph should not be |
| # applied when the caller is also in Eager mode. |
| fn = autograph.tf_convert( |
| fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) |
| |
| def reduce(self, reduce_op, value, axis): |
| """Reduce `value` across replicas and return result on current device. |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def step_fn(): |
| ... i = tf.distribute.get_replica_context().replica_id_in_sync_group |
| ... return tf.identity(i) |
| >>> |
| >>> per_replica_result = strategy.run(step_fn) |
| >>> total = strategy.reduce("SUM", per_replica_result, axis=None) |
| >>> total |
| <tf.Tensor: shape=(), dtype=int32, numpy=1> |
| |
| To see how this would look with multiple replicas, consider the same |
| example with MirroredStrategy with 2 GPUs: |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) |
| def step_fn(): |
| i = tf.distribute.get_replica_context().replica_id_in_sync_group |
| return tf.identity(i) |
| |
| per_replica_result = strategy.run(step_fn) |
| # Check devices on which per replica result is: |
| strategy.experimental_local_results(per_replica_result)[0].device |
| # /job:localhost/replica:0/task:0/device:GPU:0 |
| strategy.experimental_local_results(per_replica_result)[1].device |
| # /job:localhost/replica:0/task:0/device:GPU:1 |
| |
| total = strategy.reduce("SUM", per_replica_result, axis=None) |
| # Check device on which reduced result is: |
| total.device |
| # /job:localhost/replica:0/task:0/device:CPU:0 |
| |
| ``` |
| |
| This API is typically used for aggregating the results returned from |
| different replicas, for reporting etc. For example, loss computed from |
| different replicas can be averaged using this API before printing. |
| |
| Note: The result is copied to the "current" device - which would typically |
| be the CPU of the worker on which the program is running. For `TPUStrategy`, |
| it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`, |
| this is the CPU of each worker. |
| |
| There are a number of different tf.distribute APIs for reducing values |
| across replicas: |
| * `tf.distribute.ReplicaContext.all_reduce`: This differs from |
| `Strategy.reduce` in that it is for replica context and does |
| not copy the results to the host device. `all_reduce` should be typically |
| used for reductions inside the training step such as gradients. |
| * `tf.distribute.StrategyExtended.reduce_to` and |
| `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more |
| advanced versions of `Strategy.reduce` as they allow customizing the |
| destination of the result. They are also called in cross replica context. |
| |
| _What should axis be?_ |
| |
| Given a per-replica value returned by `run`, say a |
| per-example loss, the batch will be divided across all the replicas. This |
| function allows you to aggregate across replicas and optionally also across |
| batch elements by specifying the axis parameter accordingly. |
| |
| For example, if you have a global batch size of 8 and 2 |
| replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and |
| `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will |
| aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. |
| This is useful when each replica is computing a scalar or some other value |
| that doesn't have a "batch" dimension (like a gradient or loss). |
| ``` |
| strategy.reduce("sum", per_replica_result, axis=None) |
| ``` |
| |
| Sometimes, you will want to aggregate across both the global batch _and_ |
| all replicas. You can get this behavior by specifying the batch |
| dimension as the `axis`, typically `axis=0`. In this case it would return a |
| scalar `0+1+2+3+4+5+6+7`. |
| ``` |
| strategy.reduce("sum", per_replica_result, axis=0) |
| ``` |
| |
| If there is a last partial batch, you will need to specify an axis so |
| that the resulting shape is consistent across replicas. So if the last |
| batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you |
| would get a shape mismatch unless you specify `axis=0`. If you specify |
| `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct |
| denominator of 6. Contrast this with computing `reduce_mean` on each |
| replica to get a scalar value and then using this function to average |
| those means, which would weigh some values `1/8` and others `1/4`. |
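| |
| For illustration, a minimal sketch of the partial-batch case above (the |
| name `per_replica_losses` and the batch split are assumptions): |
| |
| ``` |
| # Last global batch of size 6, split as [0, 1, 2, 3] on replica 0 and |
| # [4, 5] on replica 1. With MEAN and axis=0 the denominator is 4 + 2 = 6. |
| mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, |
|                             per_replica_losses, axis=0) |
| ``` |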
| |
| Args: |
| reduce_op: a `tf.distribute.ReduceOp` value specifying how values should |
| be combined. Allows using string representation of the enum such as |
| "SUM", "MEAN". |
| value: a `tf.distribute.DistributedValues` instance, e.g. returned by |
| `Strategy.run`, to be combined into a single tensor. It can also be a |
| regular tensor when used with `OneDeviceStrategy` or default strategy. |
| axis: specifies the dimension to reduce along within each |
| replica's tensor. Should typically be set to the batch dimension, or |
| `None` to only reduce across replicas (e.g. if the tensor has no batch |
| dimension). |
| |
| Returns: |
| A `Tensor`. |
| """ |
| # TODO(josh11b): support `value` being a nest. |
| _require_cross_replica_or_default_context_extended(self._extended) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| if axis is None: |
| return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access |
| if reduce_op == reduce_util.ReduceOp.SUM: |
| |
| def reduce_sum(v): |
| return math_ops.reduce_sum(v, axis=axis) |
| |
| if eager_context.executing_eagerly(): |
| # As some strategies (e.g. TPUStrategy) don't support pure eager |
| # execution, wrap `reduce_sum_fn` with a `tf.function` so it can be |
| # run from eager mode. Cache the tf.function by `axis` to avoid |
| # retracing the same function. |
| if axis not in self._reduce_sum_fns: |
| |
| def reduce_sum_fn(v): |
| return self.run(reduce_sum, args=(v,)) |
| |
| self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn) |
| value = self._reduce_sum_fns[axis](value) |
| else: |
| value = self.run(reduce_sum, args=(value,)) |
| |
| return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access |
| if reduce_op != reduce_util.ReduceOp.MEAN: |
| raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " |
| "not: %r" % reduce_op) |
| # TODO(josh11b): Support list/tuple and tensor axis values. |
| if not isinstance(axis, six.integer_types): |
| raise TypeError("Expected `axis` to be an integer, not: %r" % axis) |
| |
| def mean_reduce_helper(v, axis=axis): |
| """Computes the numerator and denominator on each replica.""" |
| numer = math_ops.reduce_sum(v, axis=axis) |
| if v.shape.rank is not None: |
| # Note(joshl): We support axis < 0 to be consistent with the |
| # tf.math.reduce_* operations. |
| if axis < 0: |
| if axis + v.shape.rank < 0: |
| raise ValueError( |
| "`axis` = %r out of range for `value` with rank %d" % |
| (axis, v.shape.rank)) |
| axis += v.shape.rank |
| elif axis >= v.shape.rank: |
| raise ValueError( |
| "`axis` = %r out of range for `value` with rank %d" % |
| (axis, v.shape.rank)) |
| # TF v2 returns `None` for unknown dimensions and an integer for |
| # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) |
| # or tensor_shape.Dimension(integer). `dimension_value` hides this |
| # difference, always returning `None` or an integer. |
| dim = tensor_shape.dimension_value(v.shape[axis]) |
| if dim is not None: |
| # By returning a python value in the static shape case, we can |
| # maybe get a fast path for reducing the denominator. |
| # TODO(b/151871486): Remove array_ops.identity after we fallback to |
| # simple reduction if inputs are all on CPU. |
| return numer, array_ops.identity( |
| constant_op.constant(dim, dtype=dtypes.int64)) |
| elif axis < 0: |
| axis = axis + array_ops.rank(v) |
| # TODO(b/151871486): Remove array_ops.identity after we fallback to simple |
| # reduction if inputs are all on CPU. |
| denom = array_ops.identity( |
| array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) |
| # TODO(josh11b): Should we cast denom to v.dtype here instead of after the |
| # reduce is complete? |
| return numer, denom |
| |
| if eager_context.executing_eagerly(): |
| # As some strategies (e.g. TPUStrategy) don't support pure eager |
| # execution, wrap `mean_reduce_helper` with a `tf.function` so it can |
| # be run from eager mode. Cache the tf.function by `axis` to avoid |
| # retracing the same function. |
| if axis not in self._mean_reduce_helper_fns: |
| |
| def mean_reduce_fn(v): |
| return self.run(mean_reduce_helper, args=(v,)) |
| |
| self._mean_reduce_helper_fns[axis] = def_function.function( |
| mean_reduce_fn) |
| numer, denom = self._mean_reduce_helper_fns[axis](value) |
| else: |
| numer, denom = self.run(mean_reduce_helper, args=(value,)) |
| |
| # TODO(josh11b): Should batch reduce here instead of doing two. |
| numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access |
| denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access |
| denom = math_ops.cast(denom, numer.dtype) |
| return math_ops.truediv(numer, denom) |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED |
| def unwrap(self, value): |
| """Returns the list of all local per-replica values contained in `value`. |
| |
| DEPRECATED: Please use `experimental_local_results` instead. |
| |
| Note: This only returns values on the workers initiated by this client. |
| When using a `tf.distribute.Strategy` like |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker |
| will be its own client, and this function will only return values |
| computed on that worker. |
| |
| Args: |
| value: A value returned by `experimental_run()`, |
| `extended.call_for_each_replica()`, or a variable created in `scope`. |
| |
| Returns: |
| A tuple of values contained in `value`. If `value` represents a single |
| value, this returns `(value,).` |
| """ |
| return self._extended._local_results(value) # pylint: disable=protected-access |
| |
| def experimental_local_results(self, value): |
| """Returns the list of all local per-replica values contained in `value`. |
| |
| Note: This only returns values on the worker initiated by this client. |
| When using a `tf.distribute.Strategy` like |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker |
| will be its own client, and this function will only return values |
| computed on that worker. |
| |
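| For example, a minimal sketch (assuming a two-replica strategy and a |
| `step_fn` defined elsewhere that returns a per-replica tensor): |
| |
| ``` |
| per_replica_result = strategy.run(step_fn) |
| # A tuple with one tensor per local replica, e.g. of length 2 here. |
| local_tensors = strategy.experimental_local_results(per_replica_result) |
| ``` |
| |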
| Args: |
| value: A value returned by `experimental_run()`, `run()`, |
| `extended.call_for_each_replica()`, or a variable created in `scope`. |
| |
| Returns: |
| A tuple of values contained in `value`. If `value` represents a single |
| value, this returns `(value,).` |
| """ |
| return self._extended._local_results(value) # pylint: disable=protected-access |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only |
| def group(self, value, name=None): |
| """Shortcut for `tf.group(self.experimental_local_results(value))`.""" |
| return self._extended._group(value, name) # pylint: disable=protected-access |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns number of replicas over which gradients are aggregated.""" |
| return self._extended._num_replicas_in_sync # pylint: disable=protected-access |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string |
| def configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| # pylint: disable=g-doc-return-or-yield,g-doc-args |
| """DEPRECATED: use `update_config_proto` instead. |
| |
| Configures the strategy class. |
| |
| DEPRECATED: This method's functionality has been split into the strategy |
| constructor and `update_config_proto`. In the future, we will allow passing |
| cluster and config_proto to the constructor to configure the strategy. And |
| `update_config_proto` can be used to update the config_proto based on the |
| specific strategy. |
| """ |
| return self._extended._configure( # pylint: disable=protected-access |
| session_config, cluster_spec, task_type, task_id) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED |
| def update_config_proto(self, config_proto): |
| """DEPRECATED TF 1.x ONLY.""" |
| return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access |
| |
| def __deepcopy__(self, memo): |
| # First do a regular deepcopy of `self`. |
| cls = self.__class__ |
| result = cls.__new__(cls) |
| memo[id(self)] = result |
| for k, v in self.__dict__.items(): |
| setattr(result, k, copy.deepcopy(v, memo)) |
| # One little fix-up: we want `result._extended` to reference `result` |
| # instead of `self`. |
| result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access |
| return result |
| |
| def __copy__(self): |
| raise RuntimeError("Must only deepcopy DistributionStrategy.") |
| |
| @property |
| def cluster_resolver(self): |
| """Returns the cluster resolver associated with this strategy. |
| |
| In general, when using a multi-worker `tf.distribute` strategy such as |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy` or |
| `tf.distribute.TPUStrategy()`, there is a |
| `tf.distribute.cluster_resolver.ClusterResolver` associated with the |
| strategy used, and such an instance is returned by this property. |
| |
| Strategies that intend to have an associated |
| `tf.distribute.cluster_resolver.ClusterResolver` must set the |
| relevant attribute, or override this property; otherwise, `None` is returned |
| by default. Those strategies should also provide information regarding what |
| is returned by this property. |
| |
| Single-worker strategies usually do not have a |
| `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this |
| property will return `None`. |
| |
| The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the |
| user needs to access information such as the cluster spec, task type or task |
| id. For example, |
| |
| ```python |
| |
| os.environ['TF_CONFIG'] = json.dumps({ |
| 'cluster': { |
| 'worker': ["localhost:12345", "localhost:23456"], |
| 'ps': ["localhost:34567"] |
| }, |
| 'task': {'type': 'worker', 'index': 0} |
| }) |
| |
| # This implicitly uses TF_CONFIG for the cluster and current task info. |
| strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() |
| |
| ... |
| |
| if strategy.cluster_resolver.task_type == 'worker': |
| # Perform something that's only applicable on workers. Since we set this |
| # as a worker above, this block will run on this particular instance. |
| elif strategy.cluster_resolver.task_type == 'ps': |
| # Perform something that's only applicable on parameter servers. Since we |
| # set this as a worker above, this block will not run on this particular |
| # instance. |
| ``` |
| |
| For more information, please see |
| `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring. |
| |
| Returns: |
| The cluster resolver associated with this strategy. Returns `None` if a |
| cluster resolver is not applicable or available in this strategy. |
| """ |
| if hasattr(self.extended, "_cluster_resolver"): |
| return self.extended._cluster_resolver # pylint: disable=protected-access |
| return None |
| |
| |
| @tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring |
| class Strategy(StrategyBase): |
| |
| __doc__ = StrategyBase.__doc__ |
| |
| def experimental_assign_to_logical_device(self, tensor, logical_device_id): |
| """Adds annotation that `tensor` will be assigned to a logical device. |
| |
| NOTE: This API is only supported in TPUStrategy for now. |
| This adds an annotation to `tensor` specifying that operations on |
| `tensor` will be invoked on logical core device id `logical_device_id`. |
| When model parallelism is used, the default behavior is that all ops |
| are placed on the zeroth logical device. |
| |
| ```python |
| |
| # Initializing TPU system with 2 logical devices and 4 replicas. |
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') |
| tf.config.experimental_connect_to_cluster(resolver) |
| topology = tf.tpu.experimental.initialize_tpu_system(resolver) |
| device_assignment = tf.tpu.experimental.DeviceAssignment.build( |
| topology, |
| computation_shape=[1, 1, 1, 2], |
| num_replicas=4) |
| strategy = tf.distribute.TPUStrategy( |
| resolver, experimental_device_assignment=device_assignment) |
| iterator = iter(inputs) |
| |
| @tf.function() |
| def step_fn(inputs): |
| output = tf.add(inputs, inputs) |
| |
| # Add operation will be executed on logical device 0. |
| output = strategy.experimental_assign_to_logical_device(output, 0) |
| return output |
| |
| strategy.run(step_fn, args=(next(iterator),)) |
| ``` |
| |
| Args: |
| tensor: Input tensor to annotate. |
| logical_device_id: Id of the logical core to which the tensor will be |
| assigned. |
| |
| Raises: |
| ValueError: The logical device id presented is not consistent with total |
| number of partitions specified by the device assignment. |
| |
| Returns: |
| Annotated tensor with the same value as `tensor`. |
| """ |
| return self._extended._experimental_assign_to_logical_device( # pylint: disable=protected-access |
| tensor, logical_device_id) |
| |
| def experimental_split_to_logical_devices(self, tensor, partition_dimensions): |
| """Adds annotation that `tensor` will be split across logical devices. |
| |
| NOTE: This API is only supported in TPUStrategy for now. |
| This adds an annotation to tensor `tensor` specifying that operations on |
| `tensor` will be split among multiple logical devices. Tensor `tensor` |
| will be split across dimensions specified by `partition_dimensions`. |
| The dimensions of `tensor` must be divisible by corresponding value in |
| `partition_dimensions`. |
| |
| For example, for a system with 8 logical devices, if `tensor` is an image |
| tensor with shape (batch_size, width, height, channel) and |
| `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split |
| 2 ways in the width dimension and 4 ways in the height dimension, and the |
| resulting partitions will be fed into 8 logical devices. |
| |
| ```python |
| # Initializing TPU system with 8 logical devices and 1 replica. |
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') |
| tf.config.experimental_connect_to_cluster(resolver) |
| topology = tf.tpu.experimental.initialize_tpu_system(resolver) |
| device_assignment = tf.tpu.experimental.DeviceAssignment.build( |
| topology, |
| computation_shape=[1, 2, 2, 2], |
| num_replicas=1) |
| strategy = tf.distribute.TPUStrategy( |
| resolver, experimental_device_assignment=device_assignment) |
| |
| iterator = iter(inputs) |
| |
| @tf.function() |
| def step_fn(inputs): |
| inputs = strategy.experimental_split_to_logical_devices( |
| inputs, [1, 2, 4, 1]) |
| |
| # model() function will be executed on 8 logical devices with `inputs` |
| # split 2 * 4 ways. |
| output = model(inputs) |
| return output |
| |
| strategy.run(step_fn, args=(next(iterator),)) |
| ``` |
| |
| Args: |
| tensor: Input tensor to annotate. |
| partition_dimensions: An unnested list of integers with the size equal to |
| rank of `tensor` specifying how `tensor` will be partitioned. The |
| product of all elements in `partition_dimensions` must be equal to the |
| total number of logical devices per replica. |
| |
| Raises: |
| ValueError: 1) If the size of `partition_dimensions` does not equal the |
| rank of `tensor`, 2) if the product of the elements of |
| `partition_dimensions` does not match the number of logical devices per |
| replica defined by the implementing DistributionStrategy's device |
| specification, or 3) if a known size of `tensor` is not divisible by the |
| corresponding value in `partition_dimensions`. |
| |
| Returns: |
| Annotated tensor with the same value as `tensor`. |
| """ |
| return self._extended._experimental_split_to_logical_devices( # pylint: disable=protected-access |
| tensor, partition_dimensions) |
| |
| def experimental_replicate_to_logical_devices(self, tensor): |
| """Adds annotation that `tensor` will be replicated to all logical devices. |
| |
| NOTE: This API is only supported in TPUStrategy for now. |
| This adds an annotation to tensor `tensor` specifying that operations on |
| `tensor` will be invoked on all logical devices. |
| |
| ```python |
| # Initializing TPU system with 8 logical devices and 1 replica. |
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') |
| tf.config.experimental_connect_to_cluster(resolver) |
| topology = tf.tpu.experimental.initialize_tpu_system(resolver) |
| device_assignment = tf.tpu.experimental.DeviceAssignment.build( |
| topology, |
| computation_shape=[1, 2, 2, 2], |
| num_replicas=1) |
| strategy = tf.distribute.TPUStrategy( |
| resolver, experimental_device_assignment=device_assignment) |
| |
| iterator = iter(inputs) |
| |
| @tf.function() |
| def step_fn(inputs): |
| images, labels = inputs |
| images = strategy.experimental_split_to_logical_devices( |
| images, [1, 2, 4, 1]) |
| |
| # model() function will be executed on 8 logical devices with `images` |
| # split 2 * 4 ways. |
| output = model(images) |
| |
| # For loss calculation, all logical devices share the same logits |
| # and labels. |
| labels = strategy.experimental_replicate_to_logical_devices(labels) |
| output = strategy.experimental_replicate_to_logical_devices(output) |
| loss = loss_fn(labels, output) |
| |
| return loss |
| |
| strategy.run(step_fn, args=(next(iterator),)) |
| ``` |
| |
| Args: |
| tensor: Input tensor to annotate. |
| |
| Returns: |
| Annotated tensor with the same value as `tensor`. |
| """ |
| return self._extended._experimental_replicate_to_logical_devices(tensor) # pylint: disable=protected-access |
| |
| def experimental_distribute_values_from_function(self, value_fn): |
| """Generates `tf.distribute.DistributedValues` from `value_fn`. |
| |
| This function is to generate `tf.distribute.DistributedValues` to pass |
| into `run`, `reduce`, or other methods that take |
| distributed values when not using datasets. |
| |
| Args: |
| value_fn: The function to run to generate values. It is called for |
| each replica with `tf.distribute.ValueContext` as the sole argument. It |
| must return a Tensor or a type that can be converted to a Tensor. |
| |
| Returns: |
| A `tf.distribute.DistributedValues` containing a value for each replica. |
| |
| Example usage: |
| |
| 1. Return constant value per replica: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def value_fn(ctx): |
| ... return tf.constant(1.) |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>, |
| <tf.Tensor: shape=(), dtype=float32, numpy=1.0>) |
| |
| 2. Distribute values in array based on replica_id: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> array_value = np.array([3., 2., 1.]) |
| >>> def value_fn(ctx): |
| ... return array_value[ctx.replica_id_in_sync_group] |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (3.0, 2.0) |
| |
| 3. Specify values using num_replicas_in_sync: |
| |
| >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| >>> def value_fn(ctx): |
| ... return ctx.num_replicas_in_sync |
| >>> distributed_values = ( |
| ... strategy.experimental_distribute_values_from_function( |
| ... value_fn)) |
| >>> local_result = strategy.experimental_local_results(distributed_values) |
| >>> local_result |
| (2, 2) |
| |
| 4. Place values on devices and distribute: |
| |
| ``` |
| strategy = tf.distribute.TPUStrategy() |
| worker_devices = strategy.extended.worker_devices |
| multiple_values = [] |
| for i in range(strategy.num_replicas_in_sync): |
| with tf.device(worker_devices[i]): |
| multiple_values.append(tf.constant(1.0)) |
| |
| def value_fn(ctx): |
| return multiple_values[ctx.replica_id_in_sync_group] |
| |
| distributed_values = ( |
|     strategy.experimental_distribute_values_from_function( |
|         value_fn)) |
| ``` |
| |
| """ |
| return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access |
| value_fn) |
| |
| |
| # TF v1.x version has additional deprecated APIs |
| @tf_export(v1=["distribute.Strategy"]) |
| class StrategyV1(StrategyBase): |
| """A list of devices with a state & compute distribution policy. |
| |
| See [the guide](https://www.tensorflow.org/guide/distribute_strategy) |
| for overview and examples. |
| |
| Note: Not all `tf.distribute.Strategy` implementations currently support |
| TensorFlow's partitioned variables (where a single variable is split across |
| multiple devices). |
| """ |
| |
| def make_dataset_iterator(self, dataset): |
| """Makes an iterator for input provided via `dataset`. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| Data from the given dataset will be distributed evenly across all the |
| compute replicas. We will assume that the input dataset is batched by the |
| global batch size. With this assumption, we will make a best effort to |
| divide each batch across all the replicas (one or more workers). |
| If this effort fails, an error will be thrown, and the user should instead |
| use `make_input_fn_iterator` which provides more control to the user, and |
| does not try to divide a batch across replicas. |
| |
| The user could also use `make_input_fn_iterator` if they want to |
| customize which input is fed to which replica/worker etc. |
| |
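| A minimal TF 1.x sketch (assuming a graph-mode `tf.compat.v1.Session` named |
| `sess` and a `global_batch_size` defined elsewhere): |
| |
| ``` |
| dataset = tf.data.Dataset.from_tensors([1.]).repeat(64).batch( |
|     global_batch_size) |
| iterator = strategy.make_dataset_iterator(dataset) |
| # Run the initialization op(s) returned by `initialize()` before stepping. |
| sess.run(iterator.initialize()) |
| ``` |
| |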
| Args: |
| dataset: `tf.data.Dataset` that will be distributed evenly across all |
| replicas. |
| |
| Returns: |
| A `tf.distribute.InputIterator` which returns inputs for each step of the |
| computation. The user should call `initialize` on the returned iterator. |
| """ |
| return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access |
| |
| def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| """Returns an iterator split across replicas created from an input function. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| The `input_fn` should take a `tf.distribute.InputContext` object where |
| information about batching and input sharding can be accessed: |
| |
| ``` |
| def input_fn(input_context): |
| batch_size = input_context.get_per_replica_batch_size(global_batch_size) |
| d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) |
| return d.shard(input_context.num_input_pipelines, |
| input_context.input_pipeline_id) |
| with strategy.scope(): |
| iterator = strategy.make_input_fn_iterator(input_fn) |
| replica_results = strategy.experimental_run(replica_fn, iterator) |
| ``` |
| |
| The `tf.data.Dataset` returned by `input_fn` should have a per-replica |
| batch size, which may be computed using |
| `input_context.get_per_replica_batch_size`. |
| |
| Args: |
| input_fn: A function taking a `tf.distribute.InputContext` object and |
| returning a `tf.data.Dataset`. |
| replication_mode: an enum value of `tf.distribute.InputReplicationMode`. |
| Only `PER_WORKER` is supported currently, which means there will be |
| a single call to `input_fn` per worker. Replicas will dequeue from the |
| local `tf.data.Dataset` on their worker. |
| |
| Returns: |
| An iterator object that should first be `.initialize()`-ed. It may then |
| either be passed to `strategy.experimental_run()`, or you can call |
| `iterator.get_next()` to get the next value to pass to |
| `strategy.extended.call_for_each_replica()`. |
| """ |
| return super(StrategyV1, self).make_input_fn_iterator( |
| input_fn, replication_mode) |
| |
| def experimental_make_numpy_dataset(self, numpy_input, session=None): |
| """Makes a tf.data.Dataset for input provided via a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
| |
| Note that you will likely need to use |
| `tf.distribute.Strategy.experimental_distribute_dataset` with the returned |
| dataset to further distribute it with the strategy. |
| |
| Example: |
| ``` |
| numpy_input = np.ones([10], dtype=np.float32) |
| dataset = strategy.experimental_make_numpy_dataset(numpy_input) |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| ``` |
| |
| Args: |
| numpy_input: A nest of NumPy input arrays that will be converted into a |
| dataset. Note that lists of NumPy arrays are stacked, as that is normal |
| `tf.data.Dataset` behavior. |
| session: (TensorFlow v1.x graph execution only) A session used for |
| initialization. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| return self.extended.experimental_make_numpy_dataset( |
| numpy_input, session=session) |
| |
| def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation |
| """Runs ops in `fn` on each replica, with inputs from `input_iterator`. |
| |
| DEPRECATED: This method is not available in TF 2.x. Please switch |
| to using `run` instead. |
| |
| When eager execution is enabled, executes ops specified by `fn` on each |
| replica. Otherwise, builds a graph to execute the ops on each replica. |
| |
| Each replica will take a single, different input from the inputs provided by |
| one `get_next` call on the input iterator. |
| |
| `fn` may call `tf.distribute.get_replica_context()` to access members such |
| as `replica_id_in_sync_group`. |
| |
| IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being |
| used, and whether eager execution is enabled, `fn` may be called one or more |
| times (once for each replica). |
| |
| Args: |
| fn: The function to run. The inputs to the function must match the outputs |
| of `input_iterator.get_next()`. The output must be a `tf.nest` of |
| `Tensor`s. |
| input_iterator: (Optional) input iterator from which the inputs are taken. |
| |
| Returns: |
| Merged return value of `fn` across replicas. The structure of the return |
| value is the same as the return value from `fn`. Each element in the |
| structure can either be `PerReplica` (if the values are unsynchronized), |
| `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a |
| single replica). |
| """ |
| return super(StrategyV1, self).experimental_run( |
| fn, input_iterator) |
| |
| def reduce(self, reduce_op, value, axis=None): |
| return super(StrategyV1, self).reduce(reduce_op, value, axis) |
| |
| reduce.__doc__ = StrategyBase.reduce.__doc__ |
| |
| def update_config_proto(self, config_proto): |
| """Returns a copy of `config_proto` modified for use with this strategy. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| The updated config contains settings needed to run the strategy, e.g. |
| configuration to run collective ops, or device filters to improve |
| distributed training performance. |
| |
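| For example, a minimal TF 1.x sketch (assuming `strategy` has already been |
| constructed): |
| |
| ``` |
| config = tf.compat.v1.ConfigProto(allow_soft_placement=True) |
| config = strategy.update_config_proto(config) |
| sess = tf.compat.v1.Session(config=config) |
| ``` |
| |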
| Args: |
| config_proto: a `tf.ConfigProto` object. |
| |
| Returns: |
| The updated copy of the `config_proto`. |
| """ |
| return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access |
| |
| |
| # NOTE(josh11b): For any strategy that needs to support tf.compat.v1, |
| # instead descend from StrategyExtendedV1. |
| @tf_export("distribute.StrategyExtended", v1=[]) |
| class StrategyExtendedV2(object): |
| """Additional APIs for algorithms that need to be distribution-aware. |
| |
| Note: For most usage of `tf.distribute.Strategy`, there should be no need to |
| call these methods, since TensorFlow libraries (such as optimizers) already |
| call these methods when needed on your behalf. |
| |
| |
| Some common use cases of functions on this page: |
| |
| * _Locality_ |
| |
| `tf.distribute.DistributedValues` can have the same _locality_ as a |
| _distributed variable_, which leads to a mirrored value residing on the same |
| devices as the variable (as opposed to the compute devices). Such values may |
| be passed to a call to `tf.distribute.StrategyExtended.update` to update the |
| value of a variable. You may use |
| `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the |
| same locality as another variable. You may convert a "PerReplica" value to a |
| variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or |
| `tf.distribute.StrategyExtended.batch_reduce_to`. |
| |
| * _How to update a distributed variable_ |
| |
| A distributed variable is a variable created on multiple devices. As discussed |
| in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute), |
| mirrored variables and SyncOnRead variables are two examples. The standard |
| pattern for updating distributed variables is to: |
| |
| 1. In your function passed to `tf.distribute.Strategy.run`, |
| compute a list of (update, variable) pairs. For example, the update might |
| be a gradient of the loss with respect to the variable. |
| 2. Switch to cross-replica mode by calling |
| `tf.distribute.get_replica_context().merge_call()` with the updates and |
| variables as arguments. |
| 3. Call |
| `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` |
| (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` |
| (for a list of variables) to sum the updates. |
| 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update |
| its value. |
| |
| Steps 2 through 4 are done automatically by class |
| `tf.keras.optimizers.Optimizer` if you call its |
| `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context. |
| |
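| As an illustration, a minimal sketch of this manual pattern (assuming a |
| distributed variable `var` and a per-replica update `grad` computed inside |
| the function passed to `run`): |
| |
| ```python |
| def step_fn(var, grad): |
|   def merge_fn(strategy, grad, var): |
|     # Step 3: sum the per-replica updates, with the variable as destination. |
|     reduced = strategy.extended.reduce_to( |
|         tf.distribute.ReduceOp.SUM, grad, destinations=var) |
|     # Step 4: apply the reduced update to each component of `var`. |
|     strategy.extended.update( |
|         var, lambda v, g: v.assign_sub(g), args=(reduced,)) |
|   # Step 2: switch to cross-replica mode. |
|   tf.distribute.get_replica_context().merge_call(merge_fn, args=(grad, var)) |
| ``` |
| |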
| In fact, a higher-level way to update a distributed variable is to call |
| `assign` on the variable, as you would with a regular `tf.Variable`. |
| You can call the method in both _replica context_ and _cross-replica context_. |
| For a _mirrored variable_, calling `assign` in _replica context_ requires you |
| to specify the `aggregation` type in the variable constructor. In that case, |
| the context switching and sync described in steps 2 through 4 are handled for |
| you. If you call `assign` on _mirrored variable_ in _cross-replica context_, |
| you can only assign a single value or assign values from another mirrored |
| variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead |
| variable_, in _replica context_, you can simply call `assign` on it and no |
| aggregation happens under the hood. In _cross-replica context_, you can only |
| assign a single value to a SyncOnRead variable. One example case is restoring |
| from a checkpoint: if the `aggregation` type of the variable is |
| `tf.VariableAggregation.SUM`, it is assumed that replica values were added |
| before checkpointing, so at the time of restoring, the value is divided by |
| the number of replicas and then assigned to each replica; if the `aggregation` |
| type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica |
| directly. |
| |
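| For example, a minimal sketch of the `assign`-based approach in replica |
| context (assuming two mirrored replicas): |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) |
| with strategy.scope(): |
|   # `aggregation` is required to call `assign` in replica context. |
|   v = tf.Variable(0., aggregation=tf.VariableAggregation.MEAN) |
| |
| def step_fn(): |
|   replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group |
|   v.assign(tf.cast(replica_id, tf.float32))  # aggregated across replicas |
| |
| strategy.run(step_fn) |
| ``` |
| |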
| """ |
| |
| def __init__(self, container_strategy): |
| self._container_strategy_weakref = weakref.ref(container_strategy) |
| self._default_device = None |
| # This property is used to determine if we should set drop_remainder=True |
| # when creating Datasets from numpy array inputs. |
| self._require_static_shapes = False |
| |
| def _container_strategy(self): |
| """Get the containing `tf.distribute.Strategy`. |
| |
| This should not generally be needed except when creating a new |
| `ReplicaContext` and to validate that the caller is in the correct |
| `scope()`. |
| |
| Returns: |
| The `tf.distribute.Strategy` such that `strategy.extended` is `self`. |
| """ |
| container_strategy = self._container_strategy_weakref() |
| assert container_strategy is not None |
| return container_strategy |
| |
| def _scope(self, strategy): |
| """Implementation of tf.distribute.Strategy.scope().""" |
| |
| def creator_with_resource_vars(next_creator, **kwargs): |
| """Variable creator to use in `_CurrentDistributionContext`.""" |
| _require_strategy_scope_extended(self) |
| kwargs["use_resource"] = True |
| kwargs["distribute_strategy"] = strategy |
| |
| # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid |
| # dereferencing a `Tensor` that has no `name`. We still need to |
| # propagate the metadata it's holding. |
| if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): |
| checkpoint_restore_uid = kwargs[ |
| "initial_value"].checkpoint_position.restore_uid |
| kwargs["initial_value"] = kwargs["initial_value"].wrapped_value |
| else: |
| checkpoint_restore_uid = None |
| |
| created = self._create_variable(next_creator, **kwargs) |
| |
| if checkpoint_restore_uid is not None: |
| # pylint: disable=protected-access |
| # Let the checkpointing infrastructure know that the variable was |
| # already restored so it doesn't waste memory loading the value again. |
| created._maybe_initialize_trackable() |
| created._update_uid = checkpoint_restore_uid |
| # pylint: enable=protected-access |
| return created |
| |
| def distributed_getter(getter, *args, **kwargs): |
| if not self._allow_variable_partition(): |
| if kwargs.pop("partitioner", None) is not None: |
| tf_logging.log_first_n( |
| tf_logging.WARN, "Partitioned variables are disabled when using " |
| "current tf.distribute.Strategy.", 1) |
| return getter(*args, **kwargs) |
| |
| return _CurrentDistributionContext( |
| strategy, |
| variable_scope.variable_creator_scope(creator_with_resource_vars), |
| variable_scope.variable_scope( |
| variable_scope.get_variable_scope(), |
| custom_getter=distributed_getter), self._default_device) |
| |
| def _allow_variable_partition(self): |
| return False |
| |
| def _create_variable(self, next_creator, **kwargs): |
| # Note: should support "colocate_with" argument. |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def variable_created_in_scope(self, v): |
| """Tests whether `v` was created while this strategy scope was active. |
| |
| Variables created inside the strategy scope are "owned" by it: |
| |
| >>> strategy = tf.distribute.MirroredStrategy() |
| >>> with strategy.scope(): |
| ... v = tf.Variable(1.) |
| >>> strategy.extended.variable_created_in_scope(v) |
| True |
| |
| Variables created outside the strategy are not owned by it: |
| |
| >>> strategy = tf.distribute.MirroredStrategy() |
| >>> v = tf.Variable(1.) |
| >>> strategy.extended.variable_created_in_scope(v) |
| False |
| |
| Args: |
| v: A `tf.Variable` instance. |
| |
| Returns: |
| True if `v` was created inside the scope, False if not. |
| """ |
| return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access |
| |
| def colocate_vars_with(self, colocate_with_variable): |
| """Scope that controls which devices variables will be created on. |
| |
| No operations should be added to the graph inside this scope; it |
| should only be used when creating variables (some implementations |
| work by changing variable creation, others work by using a |
| tf.compat.v1.colocate_with() scope). |
| |
| This may only be used inside `self.scope()`. |
| |
| Example usage: |
| |
| ``` |
| with strategy.scope(): |
| var1 = tf.Variable(...) |
| with strategy.extended.colocate_vars_with(var1): |
| # var2 and var3 will be created on the same device(s) as var1 |
| var2 = tf.Variable(...) |
| var3 = tf.Variable(...) |
| |
| def fn(v1, v2, v3): |
| # operates on v1 from var1, v2 from var2, and v3 from var3 |
| |
| # `fn` runs on every device `var1` is on, `var2` and `var3` will be there |
| # too. |
| strategy.extended.update(var1, fn, args=(var2, var3)) |
| ``` |
| |
| Args: |
| colocate_with_variable: A variable created in this strategy's `scope()`. |
| Variables created while in the returned context manager will be on the |
| same set of devices as `colocate_with_variable`. |
| |
| Returns: |
| A context manager. |
| """ |
| |
| def create_colocated_variable(next_creator, **kwargs): |
| _require_strategy_scope_extended(self) |
| kwargs["use_resource"] = True |
| kwargs["colocate_with"] = colocate_with_variable |
| return next_creator(**kwargs) |
| |
| _require_strategy_scope_extended(self) |
| self._validate_colocate_with_variable(colocate_with_variable) |
| return variable_scope.variable_creator_scope(create_colocated_variable) |
| |
| def _validate_colocate_with_variable(self, colocate_with_variable): |
| """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" |
| pass |
| |
| def _experimental_assign_to_logical_device(self, tensor, logical_device_id): |
| raise NotImplementedError("This method should be overriden by " |
| "sub-classes which support model parallelism.") |
| |
| def _experimental_split_to_logical_devices(self, tensor, |
| partition_dimensions): |
| raise NotImplementedError("This method should be overriden by " |
| "sub-classes which support model parallelism.") |
| |
| def _experimental_replicate_to_logical_devices(self, tensor): |
| raise NotImplementedError("This method should be overriden by " |
| "sub-classes which support model parallelism.") |
| |
| def _make_dataset_iterator(self, dataset): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _make_input_fn_iterator(self, input_fn, replication_mode): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _experimental_distribute_dataset(self, dataset, options): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _experimental_distribute_datasets_from_function(self, dataset_fn, |
| options): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _experimental_distribute_values_from_function(self, value_fn): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _reduce(self, reduce_op, value): |
| # Default implementation until we have an implementation for each strategy. |
| dst = device_util.current() or self._default_device or "/device:CPU:0" |
| return self._local_results(self.reduce_to(reduce_op, value, dst))[0] |
| |
| def reduce_to(self, reduce_op, value, destinations, experimental_hints=None): |
| """Combine (via e.g. sum or mean) values across replicas. |
| |
| `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed |
| variables. It supports both dense values and `tf.IndexedSlices`. |
| |
| This API currently can only be called in cross-replica context. Other |
| variants to reduce values across replicas are: |
| * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of |
| this API. |
| * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API |
| in replica context. It supports both batched and non-batched all-reduce. |
| * `tf.distribute.Strategy.reduce`: a more convenient method to reduce |
| to the host in cross-replica context. |
| |
| `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can |
| also pass in a `Tensor`, and the destinations will be the device of that |
| tensor. For all-reduce, pass the same to `value` and `destinations`. |
| |
| It can be used in `tf.distribute.ReplicaContext.merge_call` to write code |
| that works for all `tf.distribute.Strategy`. |
| |
| >>> @tf.function |
| ... def step_fn(var): |
| ... |
| ... def merge_fn(strategy, value, var): |
| ... # All-reduce the value. Note that `value` here is a |
| ... # `tf.distribute.DistributedValues`. |
| ... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, |
| ... value, destinations=var) |
| ... strategy.extended.update(var, lambda var, value: var.assign(value), |
| ... args=(reduced,)) |
| ... |
| ... value = tf.identity(1.) |
| ... tf.distribute.get_replica_context().merge_call(merge_fn, |
| ... args=(value, var)) |
| >>> |
| >>> def run(strategy): |
| ... with strategy.scope(): |
| ... v = tf.Variable(0.) |
| ... strategy.run(step_fn, args=(v,)) |
| ... return v |
| >>> |
| >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> |
| } |
| >>> run(tf.distribute.experimental.CentralStorageStrategy( |
| ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> |
| >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> |
| |
| Args: |
| reduce_op: a `tf.distribute.ReduceOp` or string. How to reduce the value. |
| value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`-like object. |
| destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a |
| `tf.Tensor`-like object, or a device string. It specifies the devices |
| to reduce to. To perform an all-reduce, pass the same to `value` and |
| `destinations`. Note that if it's a `tf.Variable`, the value is reduced |
| to the devices of that variable; this method doesn't update the variable. |
| experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints |
| to perform collective operations. See |
| `tf.distribute.experimental.CollectiveHints` for details. |
| |
| Returns: |
| A tensor or value reduced to `destinations`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(destinations, (list, tuple)) |
| assert not isinstance(reduce_op, variable_scope.VariableAggregation) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| assert (reduce_op == reduce_util.ReduceOp.SUM or |
| reduce_op == reduce_util.ReduceOp.MEAN) |
| if experimental_hints is None: |
| experimental_hints = collective_util.Hints() |
| return self._reduce_to(reduce_op, value, destinations, experimental_hints) |
| |
| def _reduce_to(self, reduce_op, value, destinations, experimental_hints): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def batch_reduce_to(self, |
| reduce_op, |
| value_destination_pairs, |
| experimental_hints=None): |
| """Combine multiple `reduce_to` calls into one for faster execution. |
| |
| Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. |
| It's more efficient than reducing each value separately. |
| |
| This API currently can only be called in cross-replica context. Other |
| variants to reduce values across replicas are: |
| * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of |
| this API. |
| * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API |
| in replica context. It supports both batched and non-batched all-reduce. |
| * `tf.distribute.Strategy.reduce`: a more convenient method to reduce |
| to the host in cross-replica context. |
| |
| See `reduce_to` for more information. |
| |
| >>> @tf.function |
| ... def step_fn(var): |
| ... |
| ... def merge_fn(strategy, value, var): |
| ... # All-reduce the value. Note that `value` here is a |
| ... # `tf.distribute.DistributedValues`. |
| ... reduced = strategy.extended.batch_reduce_to( |
| ... tf.distribute.ReduceOp.SUM, [(value, var)])[0] |
| ... strategy.extended.update(var, lambda var, value: var.assign(value), |
| ... args=(reduced,)) |
| ... |
| ... value = tf.identity(1.) |
| ... tf.distribute.get_replica_context().merge_call(merge_fn, |
| ... args=(value, var)) |
| >>> |
| >>> def run(strategy): |
| ... with strategy.scope(): |
| ... v = tf.Variable(0.) |
| ... strategy.run(step_fn, args=(v,)) |
| ... return v |
| >>> |
| >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) |
| MirroredVariable:{ |
| 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, |
| 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> |
| } |
| >>> run(tf.distribute.experimental.CentralStorageStrategy( |
| ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> |
| >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) |
| <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> |
| |
| Args: |
| reduce_op: a `tf.distribute.ReduceOp`. How to reduce the value. |
| value_destination_pairs: a sequence of (value, destinations) pairs. See |
| `reduce_to()` for descriptions. |
| experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints |
| to perform collective operations. |
| |
| Returns: |
| A list of reduced values, one per pair in `value_destination_pairs`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(reduce_op, variable_scope.VariableAggregation) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| if experimental_hints is None: |
| experimental_hints = collective_util.Hints() |
| return self._batch_reduce_to(reduce_op, value_destination_pairs, |
| experimental_hints) |
| |
| def _batch_reduce_to(self, reduce_op, value_destination_pairs, |
| experimental_hints): |
| return [ |
| self.reduce_to( |
| reduce_op, t, destinations=v, experimental_hints=experimental_hints) |
| for t, v in value_destination_pairs |
| ] |
| |
| def update(self, var, fn, args=(), kwargs=None, group=True): |
| """Run `fn` to update `var` using inputs mirrored to the same devices. |
| |
| `tf.distribute.StrategyExtended.update` takes a distributed variable `var` |
| to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It |
| applies `fn` to each component variable of `var` and passes corresponding |
| values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain |
| per-replica values. If they contain mirrored values, they will be unwrapped |
| before calling `fn`. For example, `fn` can be `assign_add` and `args` can be |
| a mirrored DistributedValues where each component contains the value to be |
| added to this mirrored variable `var`. Calling `update` will call |
| `assign_add` on each component variable of `var` with the corresponding |
| tensor value on that device. |
| |
| Example usage: |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 devices |
| with strategy.scope(): |
| v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) |
| def update_fn(v): |
| return v.assign(1.0) |
| result = strategy.extended.update(v, update_fn) |
| # result is |
| # Mirrored:{ |
| # 0: tf.Tensor(1.0, shape=(), dtype=float32), |
| # 1: tf.Tensor(1.0, shape=(), dtype=float32) |
| # } |
| ``` |
| |
| If `var` is mirrored across multiple devices, then this method implements |
| logic like the following: |
| |
| ```python |
| results = {} |
| for device, v in var: |
| with tf.device(device): |
| # args and kwargs will be unwrapped if they are mirrored. |
| results[device] = fn(v, *args, **kwargs) |
| return merged(results) |
| ``` |
| |
| Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with |
| `var`. |
| |
| Args: |
| var: Variable, possibly mirrored to multiple devices, to operate on. |
| fn: Function to call. Should take the variable as the first argument. |
| args: Tuple or list. Additional positional arguments to pass to `fn()`. |
| kwargs: Dict with keyword arguments to pass to `fn()`. |
| group: Boolean. Defaults to True. If False, the return value will be |
| unwrapped. |
| |
| Returns: |
| By default, the merged return value of `fn` across all replicas. The |
| merged result has dependencies to make sure that if it is evaluated at |
| all, the side effects (updates) will happen on every replica. If instead |
| "group=False" is specified, this function will return a nest of lists |
| where each list has an element per replica, and the caller is responsible |
| for ensuring all elements are executed. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| fn = autograph.tf_convert( |
| fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| with self._container_strategy().scope(): |
| return self._update(var, fn, args, kwargs, group) |
| |
| def _update(self, var, fn, args, kwargs, group): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @doc_controls.do_not_generate_docs |
| def update_non_slot( |
| self, colocate_with, fn, args=(), kwargs=None, group=True): |
| """Runs `fn(*args, **kwargs)` on `colocate_with` devices. |
| |
| Used to update non-slot variables. |
| |
| Args: |
| colocate_with: Devices returned by `non_slot_devices()`. |
| fn: Function to execute. |
| args: Tuple or list. Positional arguments to pass to `fn()`. |
| kwargs: Dict with keyword arguments to pass to `fn()`. |
| group: Boolean. Defaults to True. If False, the return value will be |
| unwrapped. |
| |
| Returns: |
| Return value of `fn`, possibly merged across devices. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| fn = autograph.tf_convert( |
| fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| with self._container_strategy().scope(): |
| return self._update_non_slot(colocate_with, fn, args, kwargs, group) |
| |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, group): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _local_results(self, distributed_value): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def value_container(self, value): |
| """Returns the container that this per-replica `value` belongs to. |
| |
| Args: |
| value: A value returned by `run()` or a variable created in `scope()`. |
| |
| Returns: |
| A container that `value` belongs to. |
| If value does not belong to any container (including the case of |
| container having been destroyed), returns the value itself. |
| `value in experimental_local_results(value_container(value))` will |
| always be true. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _group(self, value, name=None): |
| """Implementation of `group`.""" |
| value = nest.flatten(self._local_results(value)) |
| |
| if len(value) != 1 or name is not None: |
| return control_flow_ops.group(value, name=name) |
| # Special handling for the common case of one op. |
| v, = value |
| if hasattr(v, "op"): |
| v = v.op |
| return v |
| |
| @property |
| def experimental_require_static_shapes(self): |
| """Returns `True` if static shape is required; `False` otherwise.""" |
| return self._require_static_shapes |
| |
| @property |
| def _num_replicas_in_sync(self): |
| """Returns number of replicas over which gradients are aggregated.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def worker_devices(self): |
| """Returns the tuple of all devices used to for compute replica execution. |
| """ |
| # TODO(josh11b): More docstring |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def parameter_devices(self): |
| """Returns the tuple of all devices used to place variables.""" |
| # TODO(josh11b): More docstring |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @doc_controls.do_not_generate_docs |
| def non_slot_devices(self, var_list): |
| """Device(s) for non-slot variables. |
| |
| This method returns non-slot devices where non-slot variables are placed. |
| Users can create non-slot variables on these devices by using a block: |
| |
| ```python |
| with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)): |
| ... |
| ``` |
| |
| Args: |
| var_list: The list of variables being optimized, needed with the |
| default `tf.distribute.Strategy`. |
| Returns: |
| A sequence of devices for non-slot variables. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| """Configures the strategy class.""" |
| del session_config, cluster_spec, task_type, task_id |
| |
| def _update_config_proto(self, config_proto): |
| return copy.deepcopy(config_proto) |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings. |
| |
| Multi-worker training refers to the setup where the training is |
| distributed across multiple workers, as opposed to the case where |
| only a local process performs the training. This function is |
| used by higher-level APIs such as Keras' `model.fit()` to infer, for |
| example, whether or not a distribute coordinator should be run (and thus |
| TensorFlow servers started for communication with other servers in the |
| cluster), or whether or not saving/restoring checkpoints is relevant for |
| preemption fault tolerance. |
| |
| Subclasses should override this to provide whether the strategy is |
| currently in multi-worker setup. |
| |
| Experimental. Signature and implementation are subject to change. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| |
| @tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring |
| class StrategyExtendedV1(StrategyExtendedV2): |
| |
| __doc__ = StrategyExtendedV2.__doc__ |
| |
| def experimental_make_numpy_dataset(self, numpy_input, session=None): |
| """Makes a dataset for input provided via a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
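
    A minimal sketch (TF 1.x graph execution); `strategy` is assumed to be a
    `tf.distribute.MirroredStrategy` created by the caller:

    ```python
    import numpy as np

    features = np.random.rand(64, 10).astype(np.float32)
    with tf.compat.v1.Session() as sess:
      # Copies `features` into variables rather than embedding it as a large
      # constant in the graph.
      dataset = strategy.extended.experimental_make_numpy_dataset(
          features, session=sess)
      dataset = dataset.batch(8)
    ```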
| |
| Args: |
| numpy_input: A nest of NumPy input arrays that will be distributed evenly |
        across all replicas. Note that lists of NumPy arrays are stacked, as
| that is normal `tf.data.Dataset` behavior. |
| session: (TensorFlow v1.x graph execution only) A session used for |
| initialization. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| return self._experimental_make_numpy_dataset(numpy_input, session=session) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def broadcast_to(self, tensor, destinations): |
| """Mirror a tensor on one device to all worker devices. |
| |
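    A minimal sketch; `strategy` is assumed to be a `tf.distribute.Strategy`
    created by the caller, and `mirrored_var` a variable created under
    `strategy.scope()`:

    ```python
    with strategy.scope():
      mirrored_var = tf.Variable(1.0)
    t = tf.constant([1.0, 2.0, 3.0])
    # Copies `t` to the devices where `mirrored_var` is placed.
    mirrored_t = strategy.extended.broadcast_to(t, destinations=mirrored_var)
    ```
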
| Args: |
| tensor: A Tensor value to broadcast. |
| destinations: A mirrored variable or device string specifying the |
| destination devices to copy `tensor` to. |
| |
| Returns: |
| A value mirrored to `destinations` devices. |
| """ |
| assert destinations is not None # from old strategy.broadcast() |
| # TODO(josh11b): More docstring |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(destinations, (list, tuple)) |
| return self._broadcast_to(tensor, destinations) |
| |
| def _broadcast_to(self, tensor, destinations): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def experimental_run_steps_on_iterator(self, |
| fn, |
| iterator, |
| iterations=1, |
| initial_loop_values=None): |
| """DEPRECATED: please use `run` instead. |
| |
    Run `fn` with input from `iterator` `iterations` times.
| |
| This method can be used to run a step function for training a number of |
| times using input from a dataset. |
| |
| Args: |
| fn: function to run using this distribution strategy. The function must |
| have the following signature: `def fn(context, inputs)`. `context` is an |
| instance of `MultiStepContext` that will be passed when `fn` is run. |
| `context` can be used to specify the outputs to be returned from `fn` |
| by calling `context.set_last_step_output`. It can also be used to |
        capture non-tensor outputs by `context.set_non_tensor_output`. See
| `MultiStepContext` documentation for more information. `inputs` will |
| have same type/structure as `iterator.get_next()`. Typically, `fn` |
| will use `call_for_each_replica` method of the strategy to distribute |
| the computation over multiple replicas. |
| iterator: Iterator of a dataset that represents the input for `fn`. The |
| caller is responsible for initializing the iterator as needed. |
| iterations: (Optional) Number of iterations that `fn` should be run. |
| Defaults to 1. |
| initial_loop_values: (Optional) Initial values to be passed into the |
| loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove |
| initial_loop_values argument when we have a mechanism to infer the |
| outputs of `fn`. |
| |
| Returns: |
| Returns the `MultiStepContext` object which has the following properties, |
| among other things: |
| - run_op: An op that runs `fn` `iterations` times. |
| - last_step_outputs: A dictionary containing tensors set using |
| `context.set_last_step_output`. Evaluating this returns the value of |
| the tensors after the last iteration. |
| - non_tensor_outputs: A dictionary containing anything that was set by |
| `fn` by calling `context.set_non_tensor_output`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| with self._container_strategy().scope(): |
| return self._experimental_run_steps_on_iterator(fn, iterator, iterations, |
| initial_loop_values) |
| |
| def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, |
| initial_loop_values): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def call_for_each_replica(self, fn, args=(), kwargs=None): |
| """Run `fn` once per replica. |
| |
    `fn` may call `tf.distribute.get_replica_context()` to access methods
    such as `replica_id_in_sync_group` and `merge_call()`.
| |
    `merge_call()` is used to communicate between the replicas and
    re-enter the cross-replica context. Every replica pauses its execution
    when it encounters a `merge_call()`. The `merge_fn` function is then
    executed once, its results are unwrapped and given back to each replica
    call, and execution resumes until `fn` is complete or encounters another
    `merge_call()`. Example:
| |
| ```python |
| # Called once in "cross-replica" context. |
| def merge_fn(distribution, three_plus_replica_id): |
| # sum the values across replicas |
| return sum(distribution.experimental_local_results(three_plus_replica_id)) |
| |
| # Called once per replica in `distribution`, in a "replica" context. |
| def fn(three): |
      replica_ctx = tf.distribute.get_replica_context()
| v = three + replica_ctx.replica_id_in_sync_group |
| # Computes the sum of the `v` values across all replicas. |
| s = replica_ctx.merge_call(merge_fn, args=(v,)) |
| return s + v |
| |
| with distribution.scope(): |
| # in "cross-replica" context |
| ... |
| merged_results = distribution.run(fn, args=[3]) |
| # merged_results has the values from every replica execution of `fn`. |
| # This statement prints a list: |
| print(distribution.experimental_local_results(merged_results)) |
| ``` |
| |
| Args: |
| fn: function to run (will be run once per replica). |
| args: Tuple or list with positional arguments for `fn`. |
| kwargs: Dict with keyword arguments for `fn`. |
| |
| Returns: |
| Merged return value of `fn` across all replicas. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| with self._container_strategy().scope(): |
| return self._call_for_each_replica(fn, args, kwargs) |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def read_var(self, v): |
| """Reads the value of a variable. |
| |
| Returns the aggregate value of a replica-local variable, or the |
| (read-only) value of any other variable. |
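
    A minimal sketch; `strategy` is assumed to be a `tf.distribute.Strategy`
    created by the caller:

    ```python
    with strategy.scope():
      sync_on_read_var = tf.Variable(
          0.0,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.SUM)
    # Returns the SUM-aggregated value of the variable across replicas.
    value = strategy.extended.read_var(sync_on_read_var)
    ```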
| |
| Args: |
| v: A variable allocated within the scope of this `tf.distribute.Strategy`. |
| |
| Returns: |
| A tensor representing the value of `v`, aggregated across replicas if |
| necessary. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def experimental_between_graph(self): |
| """Whether the strategy uses between-graph replication or not. |
| |
| This is expected to return a constant value that will not be changed |
| throughout its life cycle. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def experimental_should_init(self): |
| """Whether initialization is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def should_checkpoint(self): |
| """Whether checkpointing is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def should_save_summary(self): |
| """Whether saving summaries is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| |
| # A note about the difference between the context managers |
| # `ReplicaContext` (defined here) and `_CurrentDistributionContext` |
| # (defined above) used by `tf.distribute.Strategy.scope()`: |
| # |
| # * a ReplicaContext is only present during a `run()` |
#   call (except during a `merge_call` call) and in such a scope it
| # will be returned by calls to `get_replica_context()`. Implementers of new |
| # Strategy descendants will frequently also need to |
| # define a descendant of ReplicaContext, and are responsible for |
| # entering and exiting this context. |
| # |
| # * Strategy.scope() sets up a variable_creator scope that |
| # changes variable creation calls (e.g. to make mirrored |
| # variables). This is intended as an outer scope that users enter once |
| # around their model creation and graph definition. There is no |
| # anticipated need to define descendants of _CurrentDistributionContext. |
| # It sets the current Strategy for purposes of |
| # `get_strategy()` and `has_strategy()` |
| # and switches the thread mode to a "cross-replica context". |
| @tf_export("distribute.ReplicaContext") |
| class ReplicaContext(object): |
| """A class with a collection of APIs that can be called in a replica context. |
| |
  You can use `tf.distribute.get_replica_context` to get an instance of
  `ReplicaContext`; that function can only be called inside the function
  passed to `tf.distribute.Strategy.run`.
| |
| >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) |
| >>> def func(): |
| ... replica_context = tf.distribute.get_replica_context() |
| ... return replica_context.replica_id_in_sync_group |
| >>> strategy.run(func) |
| PerReplica:{ |
| 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>, |
| 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> |
| } |
| """ |
| |
| def __init__(self, strategy, replica_id_in_sync_group): |
| """Creates a ReplicaContext. |
| |
| Args: |
| strategy: A `tf.distribute.Strategy`. |
| replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an |
| integer whenever possible to avoid issues with nested `tf.function`. It |
| accepts a `Tensor` only to be compatible with `tpu.replicate`. |
| """ |
| self._strategy = strategy |
| self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access |
| self) |
| if not (replica_id_in_sync_group is None or |
| tensor_util.is_tensor(replica_id_in_sync_group) or |
| isinstance(replica_id_in_sync_group, int)): |
| raise ValueError( |
| "replica_id_in_sync_group can only be an integer, a Tensor or None.") |
| self._replica_id_in_sync_group = replica_id_in_sync_group |
| self._summary_recording_distribution_strategy = None |
| |
| @doc_controls.do_not_generate_docs |
| def __enter__(self): |
| _push_per_thread_mode(self._thread_context) |
| |
| def replica_id_is_zero(): |
| return math_ops.equal(self.replica_id_in_sync_group, |
| constant_op.constant(0)) |
| |
| summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access |
| self._summary_recording_distribution_strategy = ( |
| summary_state.is_recording_distribution_strategy) |
| summary_state.is_recording_distribution_strategy = replica_id_is_zero |
| |
| @doc_controls.do_not_generate_docs |
| def __exit__(self, exception_type, exception_value, traceback): |
| summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access |
| summary_state.is_recording_distribution_strategy = ( |
| self._summary_recording_distribution_strategy) |
| _pop_per_thread_mode() |
| |
| def merge_call(self, merge_fn, args=(), kwargs=None): |
| """Merge args across replicas and run `merge_fn` in a cross-replica context. |
| |
| This allows communication and coordination when there are multiple calls |
| to the step_fn triggered by a call to `strategy.run(step_fn, ...)`. |
| |
| See `tf.distribute.Strategy.run` for an explanation. |
| |
| If not inside a distributed scope, this is equivalent to: |
| |
| ``` |
| strategy = tf.distribute.get_strategy() |
| with cross-replica-context(strategy): |
| return merge_fn(strategy, *args, **kwargs) |
| ``` |
| |
| Args: |
      merge_fn: Function that joins arguments from threads that are given as
        `PerReplica`. It accepts a `tf.distribute.Strategy` object as
        its first argument.
| args: List or tuple with positional per-thread arguments for `merge_fn`. |
| kwargs: Dict with keyword per-thread arguments for `merge_fn`. |
| |
| Returns: |
| The return value of `merge_fn`, except for `PerReplica` values which are |
| unpacked. |
| """ |
| require_replica_context(self) |
| if kwargs is None: |
| kwargs = {} |
| merge_fn = autograph.tf_convert( |
| merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) |
| return self._merge_call(merge_fn, args, kwargs) |
| |
| def _merge_call(self, merge_fn, args, kwargs): |
| """Default implementation for single replica.""" |
| _push_per_thread_mode( # thread-local, so not needed with multiple threads |
| distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access |
| try: |
| return merge_fn(self._strategy, *args, **kwargs) |
| finally: |
| _pop_per_thread_mode() |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns number of replicas that are kept in sync.""" |
| return self._strategy.num_replicas_in_sync |
| |
| @property |
| def replica_id_in_sync_group(self): |
| """Returns the id of the replica. |
| |
| This identifies the replica among all replicas that are kept in sync. The |
| value of the replica id can range from 0 to |
| `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1. |
| |
    NOTE: This is not guaranteed to be the same ID as the XLA replica ID used
    for low-level operations such as collective_permute.
| |
| Returns: |
| a `Tensor`. |
| """ |
    # It's important to prefer making the Tensor at call time whenever possible.
    # Keeping Tensors in global state doesn't work well with nested
    # tf.function, since it's possible that the tensor is generated in one func
    # graph and gets captured by another, which results in a subtle "An op
    # outside of the function building code is being passed a Graph tensor"
    # error. Making the tensor at call time ensures it is created in the same
    # graph where it is used. However, to be compatible with tpu.replicate(),
    # self._replica_id_in_sync_group can also be a Tensor.
| if tensor_util.is_tensor(self._replica_id_in_sync_group): |
| return self._replica_id_in_sync_group |
| return constant_op.constant( |
| self._replica_id_in_sync_group, |
| dtypes.int32, |
| name="replica_id_in_sync_group") |
| |
| @property |
| def strategy(self): |
| """The current `tf.distribute.Strategy` object.""" |
| return self._strategy |
| |
| @property |
| @deprecation.deprecated(None, "Please avoid relying on devices property.") |
| def devices(self): |
| """Returns the devices this replica is to be executed on, as a tuple of strings. |
| |
| NOTE: For `tf.distribute.MirroredStrategy` and |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a |
    nested list of device strings, e.g., [["GPU:0"]].
| """ |
| require_replica_context(self) |
| return (device_util.current(),) |
| |
  def all_reduce(self, reduce_op, value, experimental_hints=None):
    """All-reduces the given `value` nest of `Tensor`s across replicas.
| |
| If `all_reduce` is called in any replica, it must be called in all replicas. |
| The nested structure and `Tensor` shapes must be identical in all replicas. |
| |
| IMPORTANT: The ordering of communications must be identical in all replicas. |
| |
| Example with two replicas: |
| Replica 0 `value`: {'a': 1, 'b': [40, 1]} |
| Replica 1 `value`: {'a': 3, 'b': [ 2, 98]} |
| |
| If `reduce_op` == `SUM`: |
| Result (on all replicas): {'a': 4, 'b': [42, 99]} |
| |
| If `reduce_op` == `MEAN`: |
| Result (on all replicas): {'a': 2, 'b': [21, 49.5]} |
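
    A minimal sketch; `strategy` is assumed to be a
    `tf.distribute.MirroredStrategy` created by the caller:

    ```python
    @tf.function
    def step_fn():
      ctx = tf.distribute.get_replica_context()
      value = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
      # Every replica receives the sum of `value` across all replicas.
      return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)

    strategy.run(step_fn)
    ```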
| |
| Args: |
| reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. |
| value: The nested structure of `Tensor`s to all-reduce. The structure must |
| be compatible with `tf.nest`. |
      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
| to perform collective operations. |
| |
| Returns: |
| A `Tensor` nest with the reduced `value`s from each replica. |
| """ |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| if experimental_hints is None: |
| experimental_hints = collective_util.Hints() |
| |
| def batch_all_reduce(strategy, *value_flat): |
| return strategy.extended.batch_reduce_to( |
| reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], |
| experimental_hints) |
| |
| if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: |
| # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. |
| @custom_gradient.custom_gradient |
| def grad_wrapper(*xs): |
| ys = self.merge_call(batch_all_reduce, args=xs) |
| # The gradient of an all-sum is itself an all-sum (all-mean, likewise). |
| return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) |
| return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) |
| else: |
| # TODO(cjfj): Implement gradients for other reductions. |
| reduced = nest.pack_sequence_as( |
| value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) |
| return nest.map_structure(array_ops.prevent_gradient, reduced) |
| |
| # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient |
| # all-reduce. It would return a function returning the result of reducing `t` |
| # across all replicas. The caller would wait to call this function until they |
| # needed the reduce result, allowing an efficient implementation: |
| # * With eager execution, the reduction could be performed asynchronously |
| # in the background, not blocking until the result was needed. |
| # * When constructing a graph, it could batch up all reduction requests up |
| # to that point that the first result is needed. Most likely this can be |
| # implemented in terms of `merge_call()` and `batch_reduce_to()`. |
| |
| |
| def _batch_reduce_destination(x): |
| """Returns the destinations for batch all-reduce.""" |
| if isinstance(x, ops.Tensor): |
| # If this is a one device strategy. |
| return x.device |
| else: |
| return x |
| |
| |
| # ------------------------------------------------------------------------------ |
| |
| |
| _creating_default_strategy_singleton = False |
| |
| |
| class _DefaultDistributionStrategyV1(StrategyV1): |
| """Default `tf.distribute.Strategy` if none is explicitly selected.""" |
| |
| def __init__(self): |
| if not _creating_default_strategy_singleton: |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| super(_DefaultDistributionStrategyV1, |
| self).__init__(_DefaultDistributionExtended(self)) |
| |
| def __deepcopy__(self, memo): |
| del memo |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| |
| |
| class _DefaultDistributionStrategy(Strategy): |
| """Default `tf.distribute.Strategy` if none is explicitly selected.""" |
| |
| def __init__(self): |
| if not _creating_default_strategy_singleton: |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| super(_DefaultDistributionStrategy, self).__init__( |
| _DefaultDistributionExtended(self)) |
| |
| def __deepcopy__(self, memo): |
| del memo |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| |
| |
| class _DefaultDistributionContext(object): |
| """Context manager setting the default `tf.distribute.Strategy`.""" |
| |
| __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"] |
| |
| def __init__(self, strategy): |
| |
| def creator(next_creator, **kwargs): |
| _require_strategy_scope_strategy(strategy) |
| return next_creator(**kwargs) |
| |
| self._var_creator_scope = variable_scope.variable_creator_scope(creator) |
| self._strategy = strategy |
| self._nested_count = 0 |
| |
| def __enter__(self): |
| # Allow this scope to be entered if this strategy is already in scope. |
| if distribution_strategy_context.has_strategy(): |
| raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") |
| if self._nested_count == 0: |
| self._var_creator_scope.__enter__() |
| self._nested_count += 1 |
| return self._strategy |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| self._nested_count -= 1 |
| if self._nested_count == 0: |
| try: |
| self._var_creator_scope.__exit__( |
| exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable creator scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| |
| class _DefaultDistributionExtended(StrategyExtendedV1): |
| """Implementation of _DefaultDistributionStrategy.""" |
| |
| def __init__(self, container_strategy): |
| super(_DefaultDistributionExtended, self).__init__(container_strategy) |
| self._retrace_functions_for_each_device = False |
| |
| def _scope(self, strategy): |
| """Context manager setting a variable creator and `self` as current.""" |
| return _DefaultDistributionContext(strategy) |
| |
| def colocate_vars_with(self, colocate_with_variable): |
| """Does not require `self.scope`.""" |
| _require_strategy_scope_extended(self) |
| return ops.colocate_with(colocate_with_variable) |
| |
| def variable_created_in_scope(self, v): |
| return v._distribute_strategy is None # pylint: disable=protected-access |
| |
| def _experimental_distribute_dataset(self, dataset, options): |
| return dataset |
| |
| def _experimental_distribute_datasets_from_function(self, dataset_fn, |
| options): |
| return dataset_fn(InputContext()) |
| |
| def _experimental_distribute_values_from_function(self, value_fn): |
| return value_fn(ValueContext()) |
| |
| def _make_dataset_iterator(self, dataset): |
| return _DefaultDistributionExtended.DefaultInputIterator(dataset) |
| |
| def _make_input_fn_iterator(self, |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| dataset = input_fn(InputContext()) |
| return _DefaultDistributionExtended.DefaultInputIterator(dataset) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| numpy_flat = nest.flatten(numpy_input) |
| vars_flat = tuple( |
| variable_scope.variable(array_ops.zeros(i.shape, i.dtype), |
| trainable=False, use_resource=True) |
| for i in numpy_flat |
| ) |
| for v, i in zip(vars_flat, numpy_flat): |
| numpy_dataset.init_var_from_numpy(v, i, session) |
| vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) |
| return dataset_ops.Dataset.from_tensor_slices(vars_nested) |
| |
| def _broadcast_to(self, tensor, destinations): |
| if destinations is None: |
| return tensor |
| else: |
| raise NotImplementedError("TODO") |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): |
| return fn(*args, **kwargs) |
| |
| def _reduce_to(self, reduce_op, value, destinations, experimental_hints): |
| # TODO(josh11b): Use destinations? |
| del reduce_op, destinations, experimental_hints |
| return value |
| |
| def _update(self, var, fn, args, kwargs, group): |
| # The implementations of _update() and _update_non_slot() are identical |
| # except _update() passes `var` as the first argument to `fn()`. |
| return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) |
| |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): |
| # TODO(josh11b): Figure out what we should be passing to UpdateContext() |
| # once that value is used for something. |
| with UpdateContext(colocate_with): |
| result = fn(*args, **kwargs) |
| if should_group: |
| return result |
| else: |
| return nest.map_structure(self._local_results, result) |
| |
| def read_var(self, replica_local_var): |
| return array_ops.identity(replica_local_var) |
| |
| def _local_results(self, distributed_value): |
| return (distributed_value,) |
| |
| def value_container(self, value): |
| return value |
| |
| @property |
| def _num_replicas_in_sync(self): |
| return 1 |
| |
| @property |
| def worker_devices(self): |
| raise RuntimeError("worker_devices() method unsupported by default " |
| "tf.distribute.Strategy.") |
| |
| @property |
| def parameter_devices(self): |
| raise RuntimeError("parameter_devices() method unsupported by default " |
| "tf.distribute.Strategy.") |
| |
| def non_slot_devices(self, var_list): |
| return min(var_list, key=lambda x: x.name) |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings.""" |
| # Default strategy doesn't indicate multi-worker training. |
| return False |
| |
| @property |
| def should_checkpoint(self): |
| return True |
| |
| @property |
| def should_save_summary(self): |
| return True |
| |
| # TODO(priyag): This should inherit from `InputIterator`, once dependency |
| # issues have been resolved. |
| class DefaultInputIterator(object): |
| """Default implementation of `InputIterator` for default strategy.""" |
| |
| def __init__(self, dataset): |
| self._dataset = dataset |
| if eager_context.executing_eagerly(): |
| self._iterator = dataset_ops.make_one_shot_iterator(dataset) |
| else: |
| self._iterator = dataset_ops.make_initializable_iterator(dataset) |
| |
| def get_next(self): |
| return self._iterator.get_next() |
| |
| def get_next_as_optional(self): |
| return iterator_ops.get_next_as_optional(self._iterator) |
| |
| @deprecated(None, "Use the iterator's `initializer` property instead.") |
| def initialize(self): |
| """Initialize underlying iterators. |
| |
| Returns: |
| A list of any initializer ops that should be run. |
| """ |
| if eager_context.executing_eagerly(): |
| self._iterator = self._dataset.make_one_shot_iterator() |
| return [] |
| else: |
| return [self._iterator.initializer] |
| |
| @property |
| def initializer(self): |
| """Returns a list of ops that initialize the iterator.""" |
| return self.initialize() |
| |
| # TODO(priyag): Delete this once all strategies use global batch size. |
| @property |
| def _global_batch_size(self): |
| """Global and per-replica batching are equivalent for this strategy.""" |
| return True |
| |
| |
| class _DefaultReplicaContext(ReplicaContext): |
| """ReplicaContext for _DefaultDistributionStrategy.""" |
| |
| @property |
| def replica_id_in_sync_group(self): |
| # Return 0 instead of a constant tensor to avoid creating a new node for |
| # users who don't use distribution strategy. |
| return 0 |
| |
| |
| # ------------------------------------------------------------------------------ |
| # We haven't yet implemented deserialization for DistributedVariables. |
| # So here we catch any attempts to deserialize variables |
| # when using distribution strategies. |
| # pylint: disable=protected-access |
| _original_from_proto = resource_variable_ops._from_proto_fn |
| |
| |
| def _from_proto_fn(v, import_scope=None): |
| if distribution_strategy_context.has_strategy(): |
| raise NotImplementedError( |
| "Deserialization of variables is not yet supported when using a " |
| "tf.distribute.Strategy.") |
| else: |
| return _original_from_proto(v, import_scope=import_scope) |
| |
| resource_variable_ops._from_proto_fn = _from_proto_fn |
| # pylint: enable=protected-access |
| |
| |
| #------------------------------------------------------------------------------- |
| # Shorthand for some methods from distribution_strategy_context. |
| _push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access |
| _get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access |
| _pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access |
| _get_default_replica_mode = ( |
| distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Metrics to track which distribution strategy is being called |
| distribution_strategy_gauge = monitoring.StringGauge( |
| "/tensorflow/api/distribution_strategy", |
| "Gauge to track the type of distribution strategy used.", "TFVersion") |
distribution_strategy_replica_gauge = monitoring.IntGauge(
    "/tensorflow/api/distribution_strategy/replica",
    "Gauge to track the number of replicas each distribution strategy uses.",
| "CountType") |