| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Library for running a computation across multiple devices. |
| |
| See the guide for overview and examples: |
| [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training), |
| [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb). # pylint: disable=line-too-long |
| |
| The intent of this library is that you can write an algorithm in a stylized way |
| and it will be usable with a variety of different `tf.distribute.Strategy` |
| implementations. Each descendant will implement a different strategy for |
| distributing the algorithm across multiple devices/machines. Furthermore, these |
| changes can be hidden inside the specific layers and other library classes that |
| need special treatment to run in a distributed setting, so that most users' |
| model definition code can run unchanged. The `tf.distribute.Strategy` API works |
| the same way with eager and graph execution. |
| |
| *Glossary* |
| |
| * _Data parallelism_ is where we run multiple copies of the model |
| on different slices of the input data. This is in contrast to |
| _model parallelism_ where we divide up a single copy of a model |
| across multiple devices. |
| Note: we only support data parallelism for now, but |
| hope to add support for model parallelism in the future. |
| * A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that |
| TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple |
| devices on a single machine, or be connected to devices on multiple |
| machines. Devices used to run computations are called _worker devices_. |
| Devices used to store variables are _parameter devices_. For some strategies, |
| such as `tf.distribute.MirroredStrategy`, the worker and parameter devices |
| will be the same (see mirrored variables below). For others they will be |
| different. For example, `tf.distribute.experimental.CentralStorageStrategy` |
| puts the variables on a single device (which may be a worker device or may be |
| the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the |
| variables on separate machines called parameter servers (see below). |
| * A _replica_ is one copy of the model, running on one slice of the |
| input data. Right now each replica is executed on its own |
| worker device, but once we add support for model parallelism |
| a replica may span multiple worker devices. |
| * A _host_ is the CPU device on a machine with worker devices, typically |
| used for running input pipelines. |
| * A _worker_ is defined to be the physical machine(s) containing the physical |
| devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A |
| worker may contain one or more replicas, but contains at least one |
| replica. Typically one worker will correspond to one machine, but in the case |
| of very large models with model parallelism, one worker may span multiple |
| machines. We typically run one input pipeline per worker, feeding all the |
| replicas on that worker. |
| * _Synchronous_, or more commonly _sync_, training is where the updates from |
| each replica are aggregated together before updating the model variables. This |
| is in contrast to _asynchronous_, or _async_ training, where each replica |
| updates the model variables independently. You may also have replicas |
| partitioned into groups which are in sync within each group but async between |
| groups. |
| * _Parameter servers_: These are machines that hold a single copy of |
| parameters/variables, used by some strategies (right now just |
| `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want |
| to operate on a variable retrieve it at the beginning of a step and send an |
| update to be applied at the end of the step. These can in priniciple support |
| either sync or async training, but right now we only have support for async |
| training with parameter servers. Compare to |
| `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables |
| on a single device on the same machine (and does sync training), and |
| `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices |
| (see below). |
| * _Mirrored variables_: These are variables that are copied to multiple |
| devices, where we keep the copies in sync by applying the same |
| updates to every copy. Normally would only be used with sync training. |
| * Reductions and all-reduce: A _reduction_ is some method of aggregating |
| multiple values into one value, like "sum" or "mean". If a strategy is doing |
| sync training, we will perform a reduction on the gradients to a parameter |
| from all replicas before applying the update. _All-reduce_ is an algorithm for |
| performing a reduction on values from multiple devices and making the result |
| available on all of those devices. |
| |
| Note that we provide a default version of `tf.distribute.Strategy` that is |
| used when no other strategy is in scope, that provides the same API with |
| reasonable default behavior. |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import copy |
| import enum # pylint: disable=g-bad-import-order |
| import threading |
| import weakref |
| |
| import six |
| |
| from tensorflow.python.autograph.core import ag_ctx |
| from tensorflow.python.autograph.impl import api as autograph |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute import device_util |
| from tensorflow.python.distribute import distribution_strategy_context |
| from tensorflow.python.distribute import numpy_dataset |
| from tensorflow.python.distribute import reduce_util |
| from tensorflow.python.eager import context as eager_context |
| from tensorflow.python.eager import monitoring |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import summary_ops_v2 |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops.losses import loss_reduction |
| from tensorflow.python.ops.losses import losses_impl |
| from tensorflow.python.platform import tf_logging |
| from tensorflow.python.training.tracking import base as trackable |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_contextlib |
| from tensorflow.python.util.tf_export import tf_export |
| from tensorflow.tools.docs import doc_controls |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Context tracking whether in a strategy.update() or .update_non_slot() call. |
| |
| |
| _update_device = threading.local() |
| |
| |
| def get_update_device(): |
| """Get the current device if in a `tf.distribute.Strategy.update()` call.""" |
| try: |
| return _update_device.current |
| except AttributeError: |
| return None |
| |
| |
| class UpdateContext(object): |
| """Context manager when you are in `update()` or `update_non_slot()`.""" |
| |
| def __init__(self, device): |
| self._device = device |
| self._old_device = None |
| |
| def __enter__(self): |
| self._old_device = get_update_device() |
| _update_device.current = self._device |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| del exception_type, exception_value, traceback |
| _update_device.current = self._old_device |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Public utility functions. |
| |
| |
| @tf_export(v1=["distribute.get_loss_reduction"]) |
| def get_loss_reduction(): |
| """`tf.distribute.ReduceOp` corresponding to the last loss reduction. |
| |
| This is used to decide whether loss should be scaled in optimizer (used only |
| for estimator + v1 optimizer use case). |
| |
| Returns: |
| `tf.distribute.ReduceOp` corresponding to the last loss reduction for |
| estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise. |
| """ |
| if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access |
| # If we are not in Estimator context then return 'SUM'. We do not need to |
| # scale loss in the optimizer. |
| return reduce_util.ReduceOp.SUM |
| last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access |
| if (last_reduction == losses_impl.Reduction.SUM or |
| last_reduction == loss_reduction.ReductionV2.SUM): |
| return reduce_util.ReduceOp.SUM |
| return reduce_util.ReduceOp.MEAN |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Internal API for validating the current thread mode |
| |
| |
| def _require_cross_replica_or_default_context_extended(extended): |
| """Verify in cross-replica context.""" |
| context = _get_per_thread_mode() |
| cross_replica = context.cross_replica_context |
| if cross_replica is not None and cross_replica.extended is extended: |
| return |
| if context is _get_default_replica_mode(): |
| return |
| strategy = extended._container_strategy() # pylint: disable=protected-access |
| # We have an error to report, figure out the right message. |
| if context.strategy is not strategy: |
| _wrong_strategy_scope(strategy, context) |
| assert cross_replica is None |
| raise RuntimeError("Method requires being in cross-replica context, use " |
| "get_replica_context().merge_call()") |
| |
| |
| def _wrong_strategy_scope(strategy, context): |
| # Figure out the right error message. |
| if not distribution_strategy_context.has_strategy(): |
| raise RuntimeError( |
| 'Need to be inside "with strategy.scope()" for %s' % |
| (strategy,)) |
| else: |
| raise RuntimeError( |
| "Mixing different tf.distribute.Strategy objects: %s is not %s" % |
| (context.strategy, strategy)) |
| |
| |
| def require_replica_context(replica_ctx): |
| """Verify in `replica_ctx` replica context.""" |
| context = _get_per_thread_mode() |
| if context.replica_context is replica_ctx: return |
| # We have an error to report, figure out the right message. |
| if context.replica_context is None: |
| raise RuntimeError("Need to be inside `call_for_each_replica()`") |
| if context.strategy is replica_ctx.strategy: |
| # Two different ReplicaContexts with the same tf.distribute.Strategy. |
| raise RuntimeError("Mismatching ReplicaContext.") |
| raise RuntimeError( |
| "Mismatching tf.distribute.Strategy objects: %s is not %s." % |
| (context.strategy, replica_ctx.strategy)) |
| |
| |
| def _require_strategy_scope_strategy(strategy): |
| """Verify in a `strategy.scope()` in this thread.""" |
| context = _get_per_thread_mode() |
| if context.strategy is strategy: return |
| _wrong_strategy_scope(strategy, context) |
| |
| |
| def _require_strategy_scope_extended(extended): |
| """Verify in a `distribution_strategy.scope()` in this thread.""" |
| context = _get_per_thread_mode() |
| if context.strategy.extended is extended: return |
| # Report error. |
| strategy = extended._container_strategy() # pylint: disable=protected-access |
| _wrong_strategy_scope(strategy, context) |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Internal context managers used to implement the DistributionStrategy |
| # base class |
| |
| |
| class _CurrentDistributionContext(object): |
| """Context manager setting the current `tf.distribute.Strategy`. |
| |
| Also: overrides the variable creator and optionally the current device. |
| """ |
| |
| def __init__(self, |
| strategy, |
| var_creator_scope, |
| var_scope=None, |
| default_device=None): |
| self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access |
| strategy) |
| self._var_creator_scope = var_creator_scope |
| self._var_scope = var_scope |
| if default_device: |
| self._device_scope = ops.device(default_device) |
| else: |
| self._device_scope = None |
| self._same_scope_again_count = 0 |
| |
| def __enter__(self): |
| # Allow this scope to be entered if this strategy is already in scope. |
| if distribution_strategy_context.has_strategy(): |
| _require_cross_replica_or_default_context_extended( |
| self._context.strategy.extended) |
| self._same_scope_again_count += 1 |
| else: |
| _push_per_thread_mode(self._context) |
| if self._var_scope: |
| self._var_scope.__enter__() |
| self._var_creator_scope.__enter__() |
| if self._device_scope: |
| self._device_scope.__enter__() |
| return self._context.strategy |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| if self._same_scope_again_count > 0: |
| self._same_scope_again_count -= 1 |
| return |
| if self._device_scope: |
| try: |
| self._device_scope.__exit__(exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Device scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| try: |
| self._var_creator_scope.__exit__( |
| exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable creator scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| if self._var_scope: |
| try: |
| self._var_scope.__exit__(exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| _pop_per_thread_mode() |
| |
| |
| # TODO(yuefengz): add more replication modes. |
| @tf_export("distribute.InputReplicationMode") |
| class InputReplicationMode(enum.Enum): |
| """Replication mode for input function. |
| |
| * `PER_WORKER`: The input function will be called on each worker |
| independently, creating as many input pipelines as number of workers. |
| Replicas will dequeue from the local Dataset on their worker. |
| `tf.distribute.Strategy` doesn't manage any state sharing between such |
| separate input pipelines. |
| """ |
| PER_WORKER = "PER_WORKER" |
| |
| |
| @tf_export("distribute.InputContext") |
| class InputContext(object): |
| """A class wrapping information needed by an input function. |
| |
| This is a context class that is passed to the user's input function and |
| contains information about the compute replicas and input pipelines. The |
| number of compute replicas (in sync training) helps compute the local batch |
| size from the desired global batch size for each replica. The input pipeline |
| information can be used to return a different subset of the input in each |
| replica (for e.g. shard the input pipeline, use a different input |
| source etc). |
| """ |
| |
| def __init__(self, |
| num_input_pipelines=1, |
| input_pipeline_id=0, |
| num_replicas_in_sync=1): |
| """Initializes an InputContext object. |
| |
| Args: |
| num_input_pipelines: the number of input pipelines in a cluster. |
| input_pipeline_id: the current input pipeline id, should be an int in |
| [0,`num_input_pipelines`). |
| num_replicas_in_sync: the number of replicas that are in sync. |
| """ |
| self._num_input_pipelines = num_input_pipelines |
| self._input_pipeline_id = input_pipeline_id |
| self._num_replicas_in_sync = num_replicas_in_sync |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns the number of compute replicas in sync.""" |
| return self._num_replicas_in_sync |
| |
| @property |
| def input_pipeline_id(self): |
| """Returns the input pipeline ID.""" |
| return self._input_pipeline_id |
| |
| @property |
| def num_input_pipelines(self): |
| """Returns the number of input pipelines.""" |
| return self._num_input_pipelines |
| |
| def get_per_replica_batch_size(self, global_batch_size): |
| """Returns the per-replica batch size. |
| |
| Args: |
| global_batch_size: the global batch size which should be divisible by |
| `num_replicas_in_sync`. |
| |
| Returns: |
| the per-replica batch size. |
| |
| Raises: |
| ValueError: if `global_batch_size` not divisible by |
| `num_replicas_in_sync`. |
| """ |
| if global_batch_size % self._num_replicas_in_sync != 0: |
| raise ValueError("The `global_batch_size` %r is not divisible by " |
| "`num_replicas_in_sync` %r " % |
| (global_batch_size, self._num_replicas_in_sync)) |
| return global_batch_size // self._num_replicas_in_sync |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Base classes for all distribution strategies. |
| |
| |
| # pylint: disable=line-too-long |
| @tf_export("distribute.Strategy", v1=[]) |
| class Strategy(object): |
| """A state & compute distribution policy on a list of devices. |
| |
| See [the guide](https://www.tensorflow.org/guide/distributed_training) |
| for overview and examples. |
| |
| In short: |
| |
| * To use it with Keras `compile`/`fit`, |
| [please |
| read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). |
| * You may pass descendant of `tf.distribute.Strategy` to |
| `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` |
| should distribute its computation. See |
| [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support). |
| * Otherwise, use `tf.distribute.Strategy.scope` to specify that a |
| strategy should be used when building an executing your model. |
| (This puts you in the "cross-replica context" for this strategy, which |
| means the strategy is put in control of things like variable placement.) |
| * If you are writing a custom training loop, you will need to call a few more |
| methods, |
| [see the |
| guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): |
| |
| * Start by either creating a `tf.data.Dataset` normally or using |
| `tf.distribute.experimental_make_numpy_dataset` to make a dataset out of |
| a `numpy` array. |
| * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert |
| a `tf.data.Dataset` to something that produces "per-replica" values. |
| If you want to manually specify how the dataset should be partitioned |
| across replicas, use |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function` |
| instead. |
| * Use `tf.distribute.Strategy.experimental_run_v2` to run a function |
| once per replica, taking values that may be "per-replica" (e.g. |
| from a distributed dataset) and returning "per-replica" values. |
| This function is executed in "replica context", which means each |
| operation is performed separately on each replica. |
| * Finally use a method (such as `tf.distribute.Strategy.reduce`) to |
| convert the resulting "per-replica" values into ordinary `Tensor`s. |
| |
| A custom training loop can be as simple as: |
| |
| ``` |
| with my_strategy.scope(): |
| @tf.function |
| def distribute_train_epoch(dataset): |
| def replica_fn(input): |
| # process input and return result |
| return result |
| |
| total_result = 0 |
| for x in dataset: |
| per_replica_result = my_strategy.experimental_run_v2(replica_fn, |
| args=(x,)) |
| total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, |
| per_replica_result, axis=None) |
| return total_result |
| |
| dist_dataset = my_strategy.experimental_distribute_dataset(dataset) |
| for _ in range(EPOCHS): |
| train_result = distribute_train_epoch(dist_dataset) |
| ``` |
| |
| This takes an ordinary `dataset` and `replica_fn` and runs it |
| distributed using a particular `tf.distribute.Strategy` named |
| `my_strategy` above. Any variables created in `replica_fn` are created |
| using `my_strategy`'s policy, and library functions called by |
| `replica_fn` can use the `get_replica_context()` API to implement |
| distributed-specific behavior. |
| |
| You can use the `reduce` API to aggregate results across replicas and use |
| this as a return value from one iteration over the distributed dataset. Or |
| you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to |
| accumulate metrics across steps in a given epoch. |
| |
| See the |
| [custom training loop |
| tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) |
| for a more detailed example. |
| |
| Note: `tf.distribute.Strategy` currently does not support TensorFlow's |
| partitioned variables (where a single variable is split across multiple |
| devices) at this time. |
| """ |
| # pylint: enable=line-too-long |
| |
| # TODO(josh11b): Partitioned computations, state; sharding |
| # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling |
| |
| def __init__(self, extended): |
| self._extended = extended |
| |
| # Flag that is used to indicate whether distribution strategy is used with |
| # Estimator. This is required for backward compatibility of loss scaling |
| # when using v1 optimizer with estimator. |
| self._scale_loss_for_estimator = False |
| |
| if not hasattr(extended, "_retrace_functions_for_each_device"): |
| # pylint: disable=protected-access |
| try: |
| extended._retrace_functions_for_each_device = ( |
| len(extended.worker_devices) > 1) |
| distribution_strategy_replica_gauge.get_cell("num_replicas").set( |
| self.num_replicas_in_sync) |
| except: # pylint: disable=bare-except |
| # Default for the case where extended.worker_devices can't return |
| # a sensible value. |
| extended._retrace_functions_for_each_device = True |
| |
| @property |
| def extended(self): |
| """`tf.distribute.StrategyExtended` with additional methods.""" |
| return self._extended |
| |
| @tf_contextlib.contextmanager |
| def _scale_loss_for_estimator_enabled(self): |
| """Scope which sets a flag used for scaling losses in optimizer. |
| |
| Yields: |
| `_scale_loss_for_estimator_enabled` is a context manager with a |
| side effect, but doesn't return a value. |
| """ |
| self._scale_loss_for_estimator = True |
| try: |
| yield |
| finally: |
| self._scale_loss_for_estimator = False |
| |
| def scope(self): |
| """Returns a context manager selecting this Strategy as current. |
| |
| Inside a `with strategy.scope():` code block, this thread |
| will use a variable creator set by `strategy`, and will |
| enter its "cross-replica context". |
| |
| Returns: |
| A context manager. |
| """ |
| return self._extended._scope(self) # pylint: disable=protected-access |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` |
| def colocate_vars_with(self, colocate_with_variable): |
| """DEPRECATED: use extended.colocate_vars_with() instead.""" |
| return self._extended.colocate_vars_with(colocate_with_variable) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def make_dataset_iterator(self, dataset): |
| """DEPRECATED TF 1.x ONLY.""" |
| return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def make_input_fn_iterator(self, |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| """DEPRECATED TF 1.x ONLY.""" |
| if replication_mode != InputReplicationMode.PER_WORKER: |
| raise ValueError( |
| "Input replication mode not supported: %r" % replication_mode) |
| with self.scope(): |
| return self.extended._make_input_fn_iterator( # pylint: disable=protected-access |
| input_fn, replication_mode=replication_mode) |
| |
| def experimental_make_numpy_dataset(self, numpy_input): |
| """Makes a `tf.data.Dataset` for input provided via a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
| |
| Note that you will likely need to use `experimental_distribute_dataset` |
| with the returned dataset to further distribute it with the strategy. |
| |
| Example: |
| ``` |
| numpy_input = np.ones([10], dtype=np.float32) |
| dataset = strategy.experimental_make_numpy_dataset(numpy_input) |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| ``` |
| |
| Args: |
| numpy_input: A nest of NumPy input arrays that will be converted into a |
| dataset. Note that lists of Numpy arrays are stacked, as that is normal |
| `tf.data.Dataset` behavior. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| return self.extended.experimental_make_numpy_dataset( |
| numpy_input, session=None) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only |
| def experimental_run(self, fn, input_iterator=None): |
| """DEPRECATED TF 1.x ONLY.""" |
| with self.scope(): |
| args = (input_iterator.get_next(),) if input_iterator is not None else () |
| return self.experimental_run_v2(fn, args=args) |
| |
| def experimental_distribute_dataset(self, dataset): |
| """Distributes a tf.data.Dataset instance provided via `dataset`. |
| |
| The returned distributed dataset can be iterated over similar to how |
| regular datasets can. |
| NOTE: Currently, the user cannot add any more transformations to a |
| distributed dataset. |
| |
| The following is an example: |
| |
| ```python |
| strategy = tf.distribute.MirroredStrategy() |
| |
| # Create a dataset |
| dataset = dataset_ops.Dataset.TFRecordDataset([ |
| "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]) |
| |
| # Distribute that dataset |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| # Iterate over the distributed dataset |
| for x in dist_dataset: |
| # process dataset elements |
| strategy.experimental_run_v2(train_step, args=(x,)) |
| ``` |
| |
| We will assume that the input dataset is batched by the |
| global batch size. With this assumption, we will make a best effort to |
| divide each batch across all the replicas (one or more workers). |
| |
| In a multi-worker setting, we will first attempt to distribute the dataset |
| by attempting to detect whether the dataset is being created out of |
| ReaderDatasets (e.g. TFRecordDataset, TextLineDataset, etc.) and if so, |
| attempting to shard the input files. Note that there has to be at least one |
| input file per worker. If you have less than one input file per worker, we |
| suggest that you should disable distributing your dataset using the method |
| below. |
| |
| If that attempt is unsuccessful (e.g. the dataset is created from a |
| Dataset.range), we will shard the dataset evenly at the end by appending a |
| `.shard` operation to the end of the processing pipeline. This will cause |
| the entire preprocessing pipeline for all the data to be run on every |
| worker, and each worker will do redundant work. We will print a warning |
| if this method of sharding is selected. In this case, consider using |
| `experimental_distribute_datasets_from_function` instead. |
| |
| You can disable dataset sharding across workers using the `auto_shard` |
| option in `tf.data.experimental.DistributeOptions`. |
| |
| Within each worker, we will also split the data among all the worker |
| devices (if more than one a present), and this will happen even if |
| multi-worker sharding is disabled using the method above. |
| |
| If the above batch splitting and dataset sharding logic is undesirable, |
| please use `experimental_distribute_datasets_from_function` instead, which |
| does not do any automatic splitting or sharding. |
| |
| Args: |
| dataset: `tf.data.Dataset` that will be sharded across all replicas using |
| the rules stated above. |
| |
| Returns: |
| A "distributed `Dataset`", which acts like a `tf.data.Dataset` except |
| it produces "per-replica" values. |
| """ |
| return self._extended._experimental_distribute_dataset(dataset) # pylint: disable=protected-access |
| |
| def experimental_distribute_datasets_from_function(self, dataset_fn): |
| """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. |
| |
| `dataset_fn` will be called once for each worker in the strategy. Each |
| replica on that worker will dequeue one batch of inputs from the local |
| `Dataset` (i.e. if a worker has two replicas, two batches will be dequeued |
| from the `Dataset` every step). |
| |
| This method can be used for several purposes. For example, where |
| `experimental_distribute_dataset` is unable to shard the input files, this |
| method might be used to manually shard the dataset (avoiding the slow |
| fallback behavior in `experimental_distribute_dataset`). In cases where the |
| dataset is infinite, this sharding can be done by creating dataset replicas |
| that differ only in their random seed. |
| `experimental_distribute_dataset` may also sometimes fail to split the |
| batch across replicas on a worker. In that case, this method can be used |
| where that limitation does not exist. |
| |
| The `dataset_fn` should take an `tf.distribute.InputContext` instance where |
| information about batching and input replication can be accessed: |
| |
| ``` |
| def dataset_fn(input_context): |
| batch_size = input_context.get_per_replica_batch_size(global_batch_size) |
| d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) |
| return d.shard( |
| input_context.num_input_pipelines, input_context.input_pipeline_id) |
| |
| inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn) |
| |
| for batch in inputs: |
| replica_results = strategy.experimental_run_v2(replica_fn, args=(batch,)) |
| ``` |
| |
| IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a |
| per-replica batch size, unlike `experimental_distribute_dataset`, which uses |
| the global batch size. This may be computed using |
| `input_context.get_per_replica_batch_size`. |
| |
| Args: |
| dataset_fn: A function taking a `tf.distribute.InputContext` instance and |
| returning a `tf.data.Dataset`. |
| |
| Returns: |
| A "distributed `Dataset`", which acts like a `tf.data.Dataset` except |
| it produces "per-replica" values. |
| """ |
| return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access |
| dataset_fn) |
| |
| def experimental_run_v2(self, fn, args=(), kwargs=None): |
| """Run `fn` on each replica, with the given arguments. |
| |
| Executes ops specified by `fn` on each replica. If `args` or `kwargs` have |
| "per-replica" values, such as those produced by a "distributed `Dataset`", |
| when `fn` is executed on a particular replica, it will be executed with the |
| component of those "per-replica" values that correspond to that replica. |
| |
| `fn` may call `tf.distribute.get_replica_context()` to access members such |
| as `all_reduce`. |
| |
| All arguments in `args` or `kwargs` should either be nest of tensors or |
| per-replica objects containing tensors or composite tensors. |
| |
| IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and |
| whether eager execution is enabled, `fn` may be called one or more times ( |
| once for each replica). |
| |
| Args: |
| fn: The function to run. The output must be a `tf.nest` of `Tensor`s. |
| args: (Optional) Positional arguments to `fn`. |
| kwargs: (Optional) Keyword arguments to `fn`. |
| |
| Returns: |
| Merged return value of `fn` across replicas. The structure of the return |
| value is the same as the return value from `fn`. Each element in the |
| structure can either be "per-replica" `Tensor` objects or `Tensor`s |
| (for example, if running on a single replica). |
| """ |
| with self.scope(): |
| # tf.distribute supports Eager functions, so AutoGraph should not be |
| # applied when when the caller is also in Eager mode. |
| fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx(), |
| convert_by_default=False) |
| return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) |
| |
| def reduce(self, reduce_op, value, axis): |
| """Reduce `value` across replicas. |
| |
| Given a per-replica value returned by `experimental_run_v2`, say a |
| per-example loss, the batch will be divided across all the replicas. This |
| function allows you to aggregate across replicas and optionally also across |
| batch elements. For example, if you have a global batch size of 8 and 2 |
| replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and |
| `[4, 5, 6, 7]` will be on replica 1. By default, `reduce` will just |
| aggregate across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is useful |
| when each replica is computing a scalar or some other value that doesn't |
| have a "batch" dimension (like a gradient). More often you will want to |
| aggregate across the global batch, which you can get by specifying the batch |
| dimension as the `axis`, typically `axis=0`. In this case it would return a |
| scalar `0+1+2+3+4+5+6+7`. |
| |
| If there is a last partial batch, you will need to specify an axis so |
| that the resulting shape is consistent across replicas. So if the last |
| batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you |
| would get a shape mismatch unless you specify `axis=0`. If you specify |
| `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct |
| denominator of 6. Contrast this with computing `reduce_mean` to get a |
| scalar value on each replica and this function to average those means, |
| which will weigh some values `1/8` and others `1/4`. |
| |
| Args: |
| reduce_op: A `tf.distribute.ReduceOp` value specifying how values should |
| be combined. |
| value: A "per replica" value, e.g. returned by `experimental_run_v2` to |
| be combined into a single tensor. |
| axis: Specifies the dimension to reduce along within each |
| replica's tensor. Should typically be set to the batch dimension, or |
| `None` to only reduce across replicas (e.g. if the tensor has no batch |
| dimension). |
| |
| Returns: |
| A `Tensor`. |
| """ |
| # TODO(josh11b): support `value` being a nest. |
| _require_cross_replica_or_default_context_extended(self._extended) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| if axis is None: |
| return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access |
| if reduce_op == reduce_util.ReduceOp.SUM: |
| value = self.experimental_run_v2( |
| lambda v: math_ops.reduce_sum(v, axis=axis), args=(value,)) |
| return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access |
| if reduce_op != reduce_util.ReduceOp.MEAN: |
| raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " |
| "not: %r" % reduce_op) |
| # TODO(josh11b): Support list/tuple and tensor axis values. |
| if not isinstance(axis, six.integer_types): |
| raise TypeError("Expected `axis` to be an integer not: %r" % axis) |
| |
| def mean_reduce_helper(v, axis=axis): |
| """Computes the numerator and denominator on each replica.""" |
| numer = math_ops.reduce_sum(v, axis=axis) |
| if v.shape.rank is not None: |
| # Note(joshl): We support axis < 0 to be consistent with the |
| # tf.math.reduce_* operations. |
| if axis < 0: |
| if axis + v.shape.rank < 0: |
| raise ValueError( |
| "`axis` = %r out of range for `value` with rank %d" % |
| (axis, v.shape.rank)) |
| axis += v.shape.rank |
| elif axis >= v.shape.rank: |
| raise ValueError( |
| "`axis` = %r out of range for `value` with rank %d" % |
| (axis, v.shape.rank)) |
| # TF v2 returns `None` for unknown dimensions and an integer for |
| # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) |
| # or tensor_shape.Dimension(integer). `dimension_value` hides this |
| # difference, always returning `None` or an integer. |
| dim = tensor_shape.dimension_value(v.shape[axis]) |
| if dim is not None: |
| # By returning a python value in the static shape case, we can |
| # maybe get a fast path for reducing the denominator. |
| return numer, dim |
| elif axis < 0: |
| axis = axis + array_ops.rank(v) |
| if v.shape.rank == 1: |
| # TODO(b/139422050): Currently tf.shape is not supported in TPU dynamic |
| # padder, use tf.size instead to workaround if the rank is 1. |
| denom = array_ops.size(v, out_type=dtypes.int64) |
| else: |
| denom = array_ops.shape_v2(v, out_type=dtypes.int64)[axis] |
| # TODO(josh11b): Should we cast denom to v.dtype here instead of after the |
| # reduce is complete? |
| return numer, denom |
| |
| numer, denom = self.experimental_run_v2(mean_reduce_helper, args=(value,)) |
| # TODO(josh11b): Should batch reduce here instead of doing two. |
| numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access |
| denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access |
| denom = math_ops.cast(denom, numer.dtype) |
| return math_ops.truediv(numer, denom) |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED |
| def unwrap(self, value): |
| """Returns the list of all local per-replica values contained in `value`. |
| |
| DEPRECATED: Please use `experimental_local_results` instead. |
| |
| Note: This only returns values on the workers initiated by this client. |
| When using a `tf.distribute.Strategy` like |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker |
| will be its own client, and this function will only return values |
| computed on that worker. |
| |
| Args: |
| value: A value returned by `experimental_run()`, |
| `extended.call_for_each_replica()`, or a variable created in `scope`. |
| |
| Returns: |
| A tuple of values contained in `value`. If `value` represents a single |
| value, this returns `(value,).` |
| """ |
| return self._extended._local_results(value) # pylint: disable=protected-access |
| |
| def experimental_local_results(self, value): |
| """Returns the list of all local per-replica values contained in `value`. |
| |
| Note: This only returns values on the worker initiated by this client. |
| When using a `tf.distribute.Strategy` like |
| `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker |
| will be its own client, and this function will only return values |
| computed on that worker. |
| |
| Args: |
| value: A value returned by `experimental_run()`, `experimental_run_v2()`, |
| `extended.call_for_each_replica()`, or a variable created in `scope`. |
| |
| Returns: |
| A tuple of values contained in `value`. If `value` represents a single |
| value, this returns `(value,).` |
| """ |
| return self._extended._local_results(value) # pylint: disable=protected-access |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only |
| def group(self, value, name=None): |
| """Shortcut for `tf.group(self.experimental_local_results(value))`.""" |
| return self._extended._group(value, name) # pylint: disable=protected-access |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns number of replicas over which gradients are aggregated.""" |
| return self._extended._num_replicas_in_sync # pylint: disable=protected-access |
| |
| @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string |
| def configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| # pylint: disable=g-doc-return-or-yield,g-doc-args |
| """DEPRECATED: use `update_config_proto` instead. |
| |
| Configures the strategy class. |
| |
| DEPRECATED: This method's functionality has been split into the strategy |
| constructor and `update_config_proto`. In the future, we will allow passing |
| cluster and config_proto to the constructor to configure the strategy. And |
| `update_config_proto` can be used to update the config_proto based on the |
| specific strategy. |
| """ |
| return self._extended._configure( # pylint: disable=protected-access |
| session_config, cluster_spec, task_type, task_id) |
| |
| @doc_controls.do_not_generate_docs # DEPRECATED |
| def update_config_proto(self, config_proto): |
| """DEPRECATED TF 1.x ONLY.""" |
| return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access |
| |
| def __deepcopy__(self, memo): |
| # First do a regular deepcopy of `self`. |
| cls = self.__class__ |
| result = cls.__new__(cls) |
| memo[id(self)] = result |
| for k, v in self.__dict__.items(): |
| setattr(result, k, copy.deepcopy(v, memo)) |
| # One little fix-up: we want `result._extended` to reference `result` |
| # instead of `self`. |
| result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access |
| return result |
| |
| def __copy__(self): |
| raise RuntimeError("Must only deepcopy DistributionStrategy.") |
| |
| |
| # TF v1.x version has additional deprecated APIs |
| @tf_export(v1=["distribute.Strategy"]) |
| class StrategyV1(Strategy): |
| """A list of devices with a state & compute distribution policy. |
| |
| See [the guide](https://www.tensorflow.org/guide/distribute_strategy) |
| for overview and examples. |
| |
| Note: Not all `tf.distribute.Strategy` implementations currently support |
| TensorFlow's partitioned variables (where a single variable is split across |
| multiple devices) at this time. |
| """ |
| |
| def make_dataset_iterator(self, dataset): |
| """Makes an iterator for input provided via `dataset`. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| Data from the given dataset will be distributed evenly across all the |
| compute replicas. We will assume that the input dataset is batched by the |
| global batch size. With this assumption, we will make a best effort to |
| divide each batch across all the replicas (one or more workers). |
| If this effort fails, an error will be thrown, and the user should instead |
| use `make_input_fn_iterator` which provides more control to the user, and |
| does not try to divide a batch across replicas. |
| |
| The user could also use `make_input_fn_iterator` if they want to |
| customize which input is fed to which replica/worker etc. |
| |
| Args: |
| dataset: `tf.data.Dataset` that will be distributed evenly across all |
| replicas. |
| |
| Returns: |
| An `tf.distribute.InputIterator` which returns inputs for each step of the |
| computation. User should call `initialize` on the returned iterator. |
| """ |
| return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access |
| |
| def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| """Returns an iterator split across replicas created from an input function. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| The `input_fn` should take an `tf.distribute.InputContext` object where |
| information about batching and input sharding can be accessed: |
| |
| ``` |
| def input_fn(input_context): |
| batch_size = input_context.get_per_replica_batch_size(global_batch_size) |
| d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) |
| return d.shard(input_context.num_input_pipelines, |
| input_context.input_pipeline_id) |
| with strategy.scope(): |
| iterator = strategy.make_input_fn_iterator(input_fn) |
| replica_results = strategy.experimental_run(replica_fn, iterator) |
| ``` |
| |
| The `tf.data.Dataset` returned by `input_fn` should have a per-replica |
| batch size, which may be computed using |
| `input_context.get_per_replica_batch_size`. |
| |
| Args: |
| input_fn: A function taking a `tf.distribute.InputContext` object and |
| returning a `tf.data.Dataset`. |
| replication_mode: an enum value of `tf.distribute.InputReplicationMode`. |
| Only `PER_WORKER` is supported currently, which means there will be |
| a single call to `input_fn` per worker. Replicas will dequeue from the |
| local `tf.data.Dataset` on their worker. |
| |
| Returns: |
| An iterator object that should first be `.initialize()`-ed. It may then |
| either be passed to `strategy.experimental_run()` or you can |
| `iterator.get_next()` to get the next value to pass to |
| `strategy.extended.call_for_each_replica()`. |
| """ |
| return super(StrategyV1, self).make_input_fn_iterator( |
| input_fn, replication_mode) |
| |
| def experimental_make_numpy_dataset(self, numpy_input, session=None): |
| """Makes a tf.data.Dataset for input provided via a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
| |
| Note that you will likely need to use |
| tf.distribute.Strategy.experimental_distribute_dataset |
| with the returned dataset to further distribute it with the strategy. |
| |
| Example: |
| ``` |
| numpy_input = np.ones([10], dtype=np.float32) |
| dataset = strategy.experimental_make_numpy_dataset(numpy_input) |
| dist_dataset = strategy.experimental_distribute_dataset(dataset) |
| ``` |
| |
| Args: |
| numpy_input: A nest of NumPy input arrays that will be converted into a |
| dataset. Note that lists of Numpy arrays are stacked, as that is normal |
| `tf.data.Dataset` behavior. |
| session: (TensorFlow v1.x graph execution only) A session used for |
| initialization. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| return self.extended.experimental_make_numpy_dataset( |
| numpy_input, session=session) |
| |
| def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation |
| """Runs ops in `fn` on each replica, with inputs from `input_iterator`. |
| |
| DEPRECATED: This method is not available in TF 2.x. Please switch |
| to using `experimental_run_v2` instead. |
| |
| When eager execution is enabled, executes ops specified by `fn` on each |
| replica. Otherwise, builds a graph to execute the ops on each replica. |
| |
| Each replica will take a single, different input from the inputs provided by |
| one `get_next` call on the input iterator. |
| |
| `fn` may call `tf.distribute.get_replica_context()` to access members such |
| as `replica_id_in_sync_group`. |
| |
| IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being |
| used, and whether eager execution is enabled, `fn` may be called one or more |
| times (once for each replica). |
| |
| Args: |
| fn: The function to run. The inputs to the function must match the outputs |
| of `input_iterator.get_next()`. The output must be a `tf.nest` of |
| `Tensor`s. |
| input_iterator: (Optional) input iterator from which the inputs are taken. |
| |
| Returns: |
| Merged return value of `fn` across replicas. The structure of the return |
| value is the same as the return value from `fn`. Each element in the |
| structure can either be `PerReplica` (if the values are unsynchronized), |
| `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a |
| single replica). |
| """ |
| return super(StrategyV1, self).experimental_run( |
| fn, input_iterator) |
| |
| def reduce(self, reduce_op, value, axis=None): |
| return super(StrategyV1, self).reduce(reduce_op, value, axis) |
| |
| reduce.__doc__ = Strategy.reduce.__doc__ |
| |
| def update_config_proto(self, config_proto): |
| """Returns a copy of `config_proto` modified for use with this strategy. |
| |
| DEPRECATED: This method is not available in TF 2.x. |
| |
| The updated config has something needed to run a strategy, e.g. |
| configuration to run collective ops, or device filters to improve |
| distributed training performance. |
| |
| Args: |
| config_proto: a `tf.ConfigProto` object. |
| |
| Returns: |
| The updated copy of the `config_proto`. |
| """ |
| return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access |
| |
| |
| # NOTE(josh11b): For any strategy that needs to support tf.compat.v1, |
| # instead descend from StrategyExtendedV1. |
| @tf_export("distribute.StrategyExtended", v1=[]) |
| class StrategyExtendedV2(object): |
| """Additional APIs for algorithms that need to be distribution-aware. |
| |
| Note: For most usage of `tf.distribute.Strategy`, there should be no need to |
| call these methods, since TensorFlow libraries (such as optimizers) already |
| call these methods when needed on your behalf. |
| |
| Lower-level concepts: |
| |
| * Wrapped values: In order to represent values parallel across devices |
| (either replicas or the devices associated with a particular value), we |
| wrap them in a "PerReplica" or "Mirrored" object that contains a map |
| from replica id to values. "PerReplica" is used when the value may be |
| different across replicas, and "Mirrored" when the value are the same. |
| * Unwrapping and merging: Consider calling a function `fn` on multiple |
| replicas, like `experimental_run_v2(fn, args=[w])` with an |
| argument `w` that is a wrapped value. This means `w` will have a map taking |
| replica id `0` to `w0`, replica id `11` to `w1`, etc. |
| `experimental_run_v2()` unwraps `w` before calling `fn`, so |
| it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return |
| values from `fn()`, which can possibly result in wrapped values. For |
| example, let's say `fn()` returns a tuple with three components: `(x, a, |
| v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component |
| is the same object `x` from every replica, then the first component of the |
| merged result will also be `x`. If the second component is different (`a`, |
| `b`, ...) from each replica, then the merged value will have a wrapped map |
| from replica device to the different values. If the third component is the |
| members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.), |
| then the merged result will be that mirrored variable (`v`). |
| * Worker devices vs. parameter devices: Most replica computations will |
| happen on worker devices. Since we don't yet support model |
| parallelism, there will be one worker device per replica. When using |
| parameter servers or central storage, the set of devices holding |
| variables may be different, otherwise the parameter devices might |
| match the worker devices. |
| |
| *Replica context vs. Cross-replica context* |
| |
| _replica context_ is when we are in some function that is being called once |
| for each replica. Otherwise we are in cross-replica context, which is |
| useful for calling `tf.distribute.Strategy` methods which operate across the |
| replicas (like `reduce_to()`). By default you start in a replica context |
| (the "default single replica context") and then some methods can switch you |
| back and forth. There is a third mode you can be in called _update context_ |
| used when updating variables. |
| |
| * `tf.distribute.Strategy.scope`: enters cross-replica context when |
| no other strategy is in scope. |
| * `tf.distribute.Strategy.experimental_run_v2`: calls a function in |
| replica context. |
| * `tf.distribute.ReplicaContext.merge_call`: transitions from replica |
| context to cross-replica context. |
| * `tf.distribute.StrategyExtended.update`: calls a function in an update |
| context from a cross-replica context. |
| |
| In a replica context, you may freely read the values of variables, but |
| you may only update their value if they specify a way to aggregate the |
| update using the `aggregation` parameter in the variable's constructor. |
| In a cross-replica context, you may read or write variables (writes may |
| need to be broadcast to all copies of the variable if it is mirrored). |
| |
| *Sync on read variables* |
| |
| In some cases, such as a metric, we want to accumulate a bunch of updates on |
| each replica independently and only aggregate when reading. This can be a big |
| performance win when the value is read only rarely (maybe the value is only |
| read at the end of an epoch or when checkpointing). These are variables |
| created by passing `synchronization=ON_READ` to the variable's constructor |
| (and some value for `aggregation`). |
| |
| The strategy may choose to put the variable on multiple devices, like mirrored |
| variables, but unlike mirrored variables we don't synchronize the updates to |
| them to make sure they have the same value. Instead, the synchronization is |
| performed when reading in cross-replica context. In a replica context, reads |
| and writes are performed on the local copy (we allow reads so you can write |
| code like `v = 0.9*v + 0.1*update`). We don't allow operations like |
| `v.assign_add` in a cross-replica context for sync on read variables; right |
| now we don't have a use case for such updates and depending on the aggregation |
| mode such updates may not be sensible. |
| |
| *Locality* |
| |
| Depending on how a value is produced, it will have a type that will determine |
| how it may be used. |
| |
| "Per-replica" values exist on the worker devices, with a different value for |
| each replica. They are produced by iterating through a "distributed `Dataset`" |
| returned by `tf.distribute.Strategy.experimental_distribute_dataset` and |
| `tf.distribute.Strategy.experimental_distribute_datasets_from_function`. They |
| are also the typical result returned by |
| `tf.distribute.Strategy.experimental_run_v2`. You typically can't use a |
| per-replica value directly in a cross-replica context, without first resolving |
| how to aggregate the values across replicas, for instance by using |
| `tf.distribute.Strategy.reduce`. |
| |
| "Mirrored" values are like per-replica values, except we know that the value |
| on all replicas are the same. We can safely read a mirrored value in a |
| cross-replica context by using the value on any replica. You can convert |
| a per-replica value into a mirrored value by using |
| `tf.distribute.ReplicaContext.all_reduce`. |
| |
| Values can also have the same locality as a variable, which is a mirrored |
| value but residing on the same devices as the variable (as opposed to the |
| compute devices). Such values may be passed to a call to |
| `tf.distribute.StrategyExtended.update` to update the value of a variable. |
| You may use `tf.distribute.StrategyExtended.colocate_vars_with` to give a |
| variable the same locality as another variable. This is useful, for example, |
| for "slot" variables used by an optimizer for keeping track of statistics |
| used to update a primary/model variable. You may convert a per-replica |
| value to a variable's locality by using |
| `tf.distribute.StrategyExtended.reduce_to` or |
| `tf.distribute.StrategyExtended.batch_reduce_to`. |
| |
| In addition to slot variables which should be colocated with their primary |
| variables, optimizers also define non-slot variables. These can be things like |
| "number of step updates performed" or "beta1^t" and "beta2^t". Each strategy |
| has some policy for which devices those variables should be copied too, called |
| the "non-slot devices" (some subset of the parameter devices). We require that |
| all non-slot variables are allocated on the same device, or mirrored across |
| the same set of devices. You can use |
| `tf.distribute.StrategyExtended.non_slot_devices` to pick a consistent set of |
| devices to pass to both `tf.distribute.StrategyExtended.colocate_vars_with` |
| and `tf.distribute.StrategyExtended.update_non_slot`. |
| |
| *How to update a variable* |
| |
| The standard pattern for updating variables is to: |
| |
| 1. In your function passed to `tf.distribute.Strategy.experimental_run_v2`, |
| compute a list of (update, variable) pairs. For example, the update might |
| be a the gradient of the loss with respect to the variable. |
| 2. Switch to cross-replica mode by calling |
| `tf.distribute.get_replica_context().merge_call()` with the updates and |
| variables as arguments. |
| 3. Call |
| `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` |
| (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` |
| (for a list of variables) to sum the updates. |
| and broadcast the result to the variable's devices. |
| 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update |
| its value. |
| |
| Steps 2 through 4 are done automatically by class |
| `tf.keras.optimizers.Optimizer` if you call its |
| `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context. |
| They are also done automatically if you call an `assign*` method on a (non |
| sync-on-read) variable that was constructed with an aggregation method (which |
| is used to determine the reduction used in step 3). |
| |
| *Distribute-aware layers* |
| |
| Layers are generally called in a replica context, except when defining a |
| functional model. `tf.distribute.in_cross_replica_context` will let you |
| determine which case you are in. If in a replica context, |
| the `tf.distribute.get_replica_context` function will return a |
| `tf.distribute.ReplicaContext` object. The `ReplicaContext` object has an |
| `all_reduce` method for aggregating across all replicas. Alternatively, you |
| can update variables following steps 2-4 above. |
| |
| Note: For new `tf.distribute.Strategy` implementations, please put all logic |
| in a subclass of `tf.distribute.StrategyExtended`. The only code needed for |
| the `tf.distribute.Strategy` subclass is for instantiating your subclass of |
| `tf.distribute.StrategyExtended` in the `__init__` method. |
| """ |
| |
| def __init__(self, container_strategy): |
| self._container_strategy_weakref = weakref.ref(container_strategy) |
| self._default_device = None |
| # This property is used to determine if we should set drop_remainder=True |
| # when creating Datasets from numpy array inputs. |
| self._require_static_shapes = False |
| |
| def _container_strategy(self): |
| """Get the containing `tf.distribute.Strategy`. |
| |
| This should not generally be needed except when creating a new |
| `ReplicaContext` and to validate that the caller is in the correct |
| `scope()`. |
| |
| Returns: |
| The `tf.distribute.Strategy` such that `strategy.extended` is `self`. |
| """ |
| container_strategy = self._container_strategy_weakref() |
| assert container_strategy is not None |
| return container_strategy |
| |
| def _scope(self, strategy): |
| """Implementation of tf.distribute.Strategy.scope().""" |
| def creator_with_resource_vars(*args, **kwargs): |
| """Variable creator to use in `_CurrentDistributionContext`.""" |
| _require_strategy_scope_extended(self) |
| kwargs["use_resource"] = True |
| kwargs["distribute_strategy"] = strategy |
| |
| # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid |
| # dereferencing a `Tensor` that is without a `name`. |
| # TODO(b/138130844): Revisit the following check once |
| # `CheckpointInitialValue` class is removed. |
| if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): |
| kwargs["initial_value"] = kwargs["initial_value"].wrapped_value |
| |
| return self._create_variable(*args, **kwargs) |
| |
| def distributed_getter(getter, *args, **kwargs): |
| if not self._allow_variable_partition(): |
| if kwargs.pop("partitioner", None) is not None: |
| tf_logging.log_first_n( |
| tf_logging.WARN, "Partitioned variables are disabled when using " |
| "current tf.distribute.Strategy.", 1) |
| return getter(*args, **kwargs) |
| |
| return _CurrentDistributionContext( |
| strategy, |
| variable_scope.variable_creator_scope(creator_with_resource_vars), |
| variable_scope.variable_scope( |
| variable_scope.get_variable_scope(), |
| custom_getter=distributed_getter), self._default_device) |
| |
| def _allow_variable_partition(self): |
| return False |
| |
| def _create_variable(self, next_creator, *args, **kwargs): |
| # Note: should support "colocate_with" argument. |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def variable_created_in_scope(self, v): |
| """Tests whether `v` was created while this strategy scope was active. |
| |
| Variables created inside the strategy scope are "owned" by it: |
| |
| ```python |
| strategy = tf.distribute.StrategyExtended() |
| with strategy.scope(): |
| v = tf.Variable(1.) |
| strategy.variable_created_in_scope(v) |
| True |
| ``` |
| |
| Variables created outside the strategy are not owned by it: |
| |
| ```python |
| v = tf.Variable(1.) |
| strategy.variable_created_in_scope(v) |
| False |
| ``` |
| |
| Args: |
| v: A `tf.Variable` instance. |
| |
| Returns: |
| True if `v` was created inside the scope, False if not. |
| """ |
| return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access |
| |
| def colocate_vars_with(self, colocate_with_variable): |
| """Scope that controls which devices variables will be created on. |
| |
| No operations should be added to the graph inside this scope, it |
| should only be used when creating variables (some implementations |
| work by changing variable creation, others work by using a |
| tf.compat.v1.colocate_with() scope). |
| |
| This may only be used inside `self.scope()`. |
| |
| Example usage: |
| |
| ``` |
| with strategy.scope(): |
| var1 = tf.Variable(...) |
| with strategy.extended.colocate_vars_with(var1): |
| # var2 and var3 will be created on the same device(s) as var1 |
| var2 = tf.Variable(...) |
| var3 = tf.Variable(...) |
| |
| def fn(v1, v2, v3): |
| # operates on v1 from var1, v2 from var2, and v3 from var3 |
| |
| # `fn` runs on every device `var1` is on, `var2` and `var3` will be there |
| # too. |
| strategy.extended.update(var1, fn, args=(var2, var3)) |
| ``` |
| |
| Args: |
| colocate_with_variable: A variable created in this strategy's `scope()`. |
| Variables created while in the returned context manager will be on the |
| same set of devices as `colocate_with_variable`. |
| |
| Returns: |
| A context manager. |
| """ |
| def create_colocated_variable(next_creator, *args, **kwargs): |
| _require_strategy_scope_extended(self) |
| kwargs["use_resource"] = True |
| kwargs["colocate_with"] = colocate_with_variable |
| return next_creator(*args, **kwargs) |
| |
| _require_strategy_scope_extended(self) |
| self._validate_colocate_with_variable(colocate_with_variable) |
| return variable_scope.variable_creator_scope(create_colocated_variable) |
| |
| def _validate_colocate_with_variable(self, colocate_with_variable): |
| """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" |
| pass |
| |
| def _make_dataset_iterator(self, dataset): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _make_input_fn_iterator(self, input_fn, replication_mode): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _experimental_distribute_dataset(self, dataset): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _experimental_distribute_datasets_from_function(self, dataset_fn): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _reduce(self, reduce_op, value): |
| # Default implementation until we have an implementation for each strategy. |
| return self._local_results( |
| self._reduce_to(reduce_op, value, |
| device_util.current() or "/device:CPU:0"))[0] |
| |
| def reduce_to(self, reduce_op, value, destinations): |
| """Combine (via e.g. sum or mean) values across replicas. |
| |
| Args: |
| reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. |
| value: A per-replica value with one value per replica. |
| destinations: A mirrored variable, a per-replica tensor, or a device |
| string. The return value will be copied to all destination devices (or |
| all the devices where the `destinations` value resides). To perform an |
| all-reduction, pass `value` to `destinations`. |
| |
| Returns: |
| A tensor or value mirrored to `destinations`. |
| """ |
| # TODO(josh11b): More docstring |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(destinations, (list, tuple)) |
| assert not isinstance(reduce_op, variable_scope.VariableAggregation) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| assert (reduce_op == reduce_util.ReduceOp.SUM or |
| reduce_op == reduce_util.ReduceOp.MEAN) |
| return self._reduce_to(reduce_op, value, destinations) |
| |
| def _reduce_to(self, reduce_op, value, destinations): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def batch_reduce_to(self, reduce_op, value_destination_pairs): |
| """Combine multiple `reduce_to` calls into one for faster execution. |
| |
| Args: |
| reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. |
| value_destination_pairs: A sequence of (value, destinations) |
| pairs. See `reduce_to()` for a description. |
| |
| Returns: |
| A list of mirrored values, one per pair in `value_destination_pairs`. |
| """ |
| # TODO(josh11b): More docstring |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(reduce_op, variable_scope.VariableAggregation) |
| if isinstance(reduce_op, six.string_types): |
| reduce_op = reduce_util.ReduceOp(reduce_op.upper()) |
| return self._batch_reduce_to(reduce_op, value_destination_pairs) |
| |
| def _batch_reduce_to(self, reduce_op, value_destination_pairs): |
| return [ |
| self.reduce_to(reduce_op, t, destinations=v) |
| for t, v in value_destination_pairs |
| ] |
| |
| def update(self, var, fn, args=(), kwargs=None, group=True): |
| """Run `fn` to update `var` using inputs mirrored to the same devices. |
| |
| If `var` is mirrored across multiple devices, then this implements |
| logic like: |
| |
| ``` |
| results = {} |
| for device, v in var: |
| with tf.device(device): |
| # args and kwargs will be unwrapped if they are mirrored. |
| results[device] = fn(v, *args, **kwargs) |
| return merged(results) |
| ``` |
| |
| Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`. |
| |
| Neither `args` nor `kwargs` may contain per-replica values. |
| If they contain mirrored values, they will be unwrapped before |
| calling `fn`. |
| |
| Args: |
| var: Variable, possibly mirrored to multiple devices, to operate on. |
| fn: Function to call. Should take the variable as the first argument. |
| args: Tuple or list. Additional positional arguments to pass to `fn()`. |
| kwargs: Dict with keyword arguments to pass to `fn()`. |
| group: Boolean. Defaults to True. If False, the return value will be |
| unwrapped. |
| |
| Returns: |
| By default, the merged return value of `fn` across all replicas. The |
| merged result has dependencies to make sure that if it is evaluated at |
| all, the side effects (updates) will happen on every replica. If instead |
| "group=False" is specified, this function will return a nest of lists |
| where each list has an element per replica, and the caller is responsible |
| for ensuring all elements are executed. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| with self._container_strategy().scope(): |
| return self._update(var, fn, args, kwargs, group) |
| |
| def _update(self, var, fn, args, kwargs, group): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def update_non_slot( |
| self, colocate_with, fn, args=(), kwargs=None, group=True): |
| """Runs `fn(*args, **kwargs)` on `colocate_with` devices. |
| |
| Args: |
| colocate_with: The return value of `non_slot_devices()`. |
| fn: Function to execute. |
| args: Tuple or list. Positional arguments to pass to `fn()`. |
| kwargs: Dict with keyword arguments to pass to `fn()`. |
| group: Boolean. Defaults to True. If False, the return value will be |
| unwrapped. |
| |
| Returns: |
| Return value of `fn`, possibly merged across devices. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| with self._container_strategy().scope(): |
| return self._update_non_slot(colocate_with, fn, args, kwargs, group) |
| |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, group): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _local_results(self, distributed_value): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def value_container(self, value): |
| """Returns the container that this per-replica `value` belongs to. |
| |
| Args: |
| value: A value returned by `experimental_run_v2()` or a variable |
| created in `scope()`. |
| |
| Returns: |
| A container that `value` belongs to. |
| If value does not belong to any container (including the case of |
| container having been destroyed), returns the value itself. |
| `value in experimental_local_results(value_container(value))` will |
| always be true. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _group(self, value, name=None): |
| """Implementation of `group`.""" |
| value = nest.flatten(self._local_results(value)) |
| |
| if len(value) != 1 or name is not None: |
| return control_flow_ops.group(value, name=name) |
| # Special handling for the common case of one op. |
| v, = value |
| if hasattr(v, "op"): |
| v = v.op |
| return v |
| |
| @property |
| def experimental_require_static_shapes(self): |
| """Returns `True` if static shape is required; `False` otherwise.""" |
| return self._require_static_shapes |
| |
| @property |
| def _num_replicas_in_sync(self): |
| """Returns number of replicas over which gradients are aggregated.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def worker_devices(self): |
| """Returns the tuple of all devices used to for compute replica execution. |
| """ |
| # TODO(josh11b): More docstring |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def parameter_devices(self): |
| """Returns the tuple of all devices used to place variables.""" |
| # TODO(josh11b): More docstring |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def non_slot_devices(self, var_list): |
| """Device(s) for non-slot variables. |
| |
| Create variables on these devices in a |
| `with colocate_vars_with(non_slot_devices(...)):` block. |
| Update those using `update_non_slot()`. |
| |
| Args: |
| var_list: The list of variables being optimized, needed with the |
| default `tf.distribute.Strategy`. |
| Returns: |
| A sequence of devices for non-slot variables. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def _configure(self, |
| session_config=None, |
| cluster_spec=None, |
| task_type=None, |
| task_id=None): |
| """Configures the strategy class.""" |
| del session_config, cluster_spec, task_type, task_id |
| |
| def _update_config_proto(self, config_proto): |
| return copy.deepcopy(config_proto) |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings. |
| |
| Multi-worker training refers to the setup where the training is |
| distributed across multiple workers, as opposed to the case where |
| only a local process performs the training. This function is |
| used by higher-level apis such as Keras' `model.fit()` to infer |
| for example whether or not a distribute coordinator should be run, |
| and thus TensorFlow servers should be started for communication |
| with other servers in the cluster, or whether or not saving/restoring |
| checkpoints is relevant for preemption fault tolerance. |
| |
| Subclasses should override this to provide whether the strategy is |
| currently in multi-worker setup. |
| |
| Experimental. Signature and implementation are subject to change. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| |
| @tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring |
| class StrategyExtendedV1(StrategyExtendedV2): |
| |
| __doc__ = StrategyExtendedV2.__doc__ |
| |
| def experimental_make_numpy_dataset(self, numpy_input, session=None): |
| """Makes a dataset for input provided via a numpy array. |
| |
| This avoids adding `numpy_input` as a large constant in the graph, |
| and copies the data to the machine or machines that will be processing |
| the input. |
| |
| Args: |
| numpy_input: A nest of NumPy input arrays that will be distributed evenly |
| across all replicas. Note that lists of Numpy arrays are stacked, as |
| that is normal `tf.data.Dataset` behavior. |
| session: (TensorFlow v1.x graph execution only) A session used for |
| initialization. |
| |
| Returns: |
| A `tf.data.Dataset` representing `numpy_input`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| return self._experimental_make_numpy_dataset(numpy_input, session=session) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def broadcast_to(self, tensor, destinations): |
| """Mirror a tensor on one device to all worker devices. |
| |
| Args: |
| tensor: A Tensor value to broadcast. |
| destinations: A mirrored variable or device string specifying the |
| destination devices to copy `tensor` to. |
| |
| Returns: |
| A value mirrored to `destinations` devices. |
| """ |
| assert destinations is not None # from old strategy.broadcast() |
| # TODO(josh11b): More docstring |
| _require_cross_replica_or_default_context_extended(self) |
| assert not isinstance(destinations, (list, tuple)) |
| return self._broadcast_to(tensor, destinations) |
| |
| def _broadcast_to(self, tensor, destinations): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def experimental_run_steps_on_iterator(self, |
| fn, |
| iterator, |
| iterations=1, |
| initial_loop_values=None): |
| """DEPRECATED: please use `experimental_run_v2` instead. |
| |
| Run `fn` with input from `iterator` for `iterations` times. |
| |
| This method can be used to run a step function for training a number of |
| times using input from a dataset. |
| |
| Args: |
| fn: function to run using this distribution strategy. The function must |
| have the following signature: `def fn(context, inputs)`. `context` is an |
| instance of `MultiStepContext` that will be passed when `fn` is run. |
| `context` can be used to specify the outputs to be returned from `fn` |
| by calling `context.set_last_step_output`. It can also be used to |
| capture non tensor outputs by `context.set_non_tensor_output`. See |
| `MultiStepContext` documentation for more information. `inputs` will |
| have same type/structure as `iterator.get_next()`. Typically, `fn` |
| will use `call_for_each_replica` method of the strategy to distribute |
| the computation over multiple replicas. |
| iterator: Iterator of a dataset that represents the input for `fn`. The |
| caller is responsible for initializing the iterator as needed. |
| iterations: (Optional) Number of iterations that `fn` should be run. |
| Defaults to 1. |
| initial_loop_values: (Optional) Initial values to be passed into the |
| loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove |
| initial_loop_values argument when we have a mechanism to infer the |
| outputs of `fn`. |
| |
| Returns: |
| Returns the `MultiStepContext` object which has the following properties, |
| among other things: |
| - run_op: An op that runs `fn` `iterations` times. |
| - last_step_outputs: A dictionary containing tensors set using |
| `context.set_last_step_output`. Evaluating this returns the value of |
| the tensors after the last iteration. |
| - non_tensor_outputs: A dictionatry containing anything that was set by |
| `fn` by calling `context.set_non_tensor_output`. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| with self._container_strategy().scope(): |
| return self._experimental_run_steps_on_iterator(fn, iterator, iterations, |
| initial_loop_values) |
| |
| def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, |
| initial_loop_values): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def call_for_each_replica(self, fn, args=(), kwargs=None): |
| """Run `fn` once per replica. |
| |
| `fn` may call `tf.get_replica_context()` to access methods such as |
| `replica_id_in_sync_group` and `merge_call()`. |
| |
| `merge_call()` is used to communicate between the replicas and |
| re-enter the cross-replica context. All replicas pause their execution |
| having encountered a `merge_call()` call. After that the |
| `merge_fn`-function is executed. Its results are then unwrapped and |
| given back to each replica call. After that execution resumes until |
| `fn` is complete or encounters another `merge_call()`. Example: |
| |
| ```python |
| # Called once in "cross-replica" context. |
| def merge_fn(distribution, three_plus_replica_id): |
| # sum the values across replicas |
| return sum(distribution.experimental_local_results(three_plus_replica_id)) |
| |
| # Called once per replica in `distribution`, in a "replica" context. |
| def fn(three): |
| replica_ctx = tf.get_replica_context() |
| v = three + replica_ctx.replica_id_in_sync_group |
| # Computes the sum of the `v` values across all replicas. |
| s = replica_ctx.merge_call(merge_fn, args=(v,)) |
| return s + v |
| |
| with distribution.scope(): |
| # in "cross-replica" context |
| ... |
| merged_results = distribution.experimental_run_v2(fn, args=[3]) |
| # merged_results has the values from every replica execution of `fn`. |
| # This statement prints a list: |
| print(distribution.experimental_local_results(merged_results)) |
| ``` |
| |
| Args: |
| fn: function to run (will be run once per replica). |
| args: Tuple or list with positional arguments for `fn`. |
| kwargs: Dict with keyword arguments for `fn`. |
| |
| Returns: |
| Merged return value of `fn` across all replicas. |
| """ |
| _require_cross_replica_or_default_context_extended(self) |
| if kwargs is None: |
| kwargs = {} |
| with self._container_strategy().scope(): |
| return self._call_for_each_replica(fn, args, kwargs) |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| raise NotImplementedError("must be implemented in descendants") |
| |
| def read_var(self, v): |
| """Reads the value of a variable. |
| |
| Returns the aggregate value of a replica-local variable, or the |
| (read-only) value of any other variable. |
| |
| Args: |
| v: A variable allocated within the scope of this `tf.distribute.Strategy`. |
| |
| Returns: |
| A tensor representing the value of `v`, aggregated across replicas if |
| necessary. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def experimental_between_graph(self): |
| """Whether the strategy uses between-graph replication or not. |
| |
| This is expected to return a constant value that will not be changed |
| throughout its life cycle. |
| """ |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def experimental_should_init(self): |
| """Whether initialization is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def should_checkpoint(self): |
| """Whether checkpointing is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| @property |
| def should_save_summary(self): |
| """Whether saving summaries is needed.""" |
| raise NotImplementedError("must be implemented in descendants") |
| |
| |
| # A note about the difference between the context managers |
| # `ReplicaContext` (defined here) and `_CurrentDistributionContext` |
| # (defined above) used by `tf.distribute.Strategy.scope()`: |
| # |
| # * a ReplicaContext is only present during a `experimental_run_v2()` |
| # call (except during a `merge_run` call) and in such a scope it |
| # will be returned by calls to `get_replica_context()`. Implementers of new |
| # Strategy descendants will frequently also need to |
| # define a descendant of ReplicaContext, and are responsible for |
| # entering and exiting this context. |
| # |
| # * Strategy.scope() sets up a variable_creator scope that |
| # changes variable creation calls (e.g. to make mirrored |
| # variables). This is intended as an outer scope that users enter once |
| # around their model creation and graph definition. There is no |
| # anticipated need to define descendants of _CurrentDistributionContext. |
| # It sets the current Strategy for purposes of |
| # `get_strategy()` and `has_strategy()` |
| # and switches the thread mode to a "cross-replica context". |
| @tf_export("distribute.ReplicaContext") |
| class ReplicaContext(object): |
| """`tf.distribute.Strategy` API when in a replica context. |
| |
| You can use `tf.distribute.get_replica_context` to get an instance of |
| `ReplicaContext`. This should be inside your replicated step function, such |
| as in a `tf.distribute.Strategy.experimental_run_v2` call. |
| """ |
| |
| def __init__(self, strategy, replica_id_in_sync_group): |
| self._strategy = strategy |
| self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access |
| self) |
| self._replica_id_in_sync_group = replica_id_in_sync_group |
| self._summary_recording_distribution_strategy = None |
| |
| def __enter__(self): |
| _push_per_thread_mode(self._thread_context) |
| |
| def replica_id_is_zero(): |
| return math_ops.equal(self._replica_id_in_sync_group, |
| constant_op.constant(0)) |
| |
| summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access |
| self._summary_recording_distribution_strategy = ( |
| summary_state.is_recording_distribution_strategy) |
| summary_state.is_recording_distribution_strategy = replica_id_is_zero |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access |
| summary_state.is_recording_distribution_strategy = ( |
| self._summary_recording_distribution_strategy) |
| _pop_per_thread_mode() |
| |
| def merge_call(self, merge_fn, args=(), kwargs=None): |
| """Merge args across replicas and run `merge_fn` in a cross-replica context. |
| |
| This allows communication and coordination when there are multiple calls |
| to the step_fn triggered by a call to |
| `strategy.experimental_run_v2(step_fn, ...)`. |
| |
| See `tf.distribute.Strategy.experimental_run_v2` for an |
| explanation. |
| |
| If not inside a distributed scope, this is equivalent to: |
| |
| ``` |
| strategy = tf.distribute.get_strategy() |
| with cross-replica-context(strategy): |
| return merge_fn(strategy, *args, **kwargs) |
| ``` |
| |
| Args: |
| merge_fn: Function that joins arguments from threads that are given as |
| PerReplica. It accepts `tf.distribute.Strategy` object as |
| the first argument. |
| args: List or tuple with positional per-thread arguments for `merge_fn`. |
| kwargs: Dict with keyword per-thread arguments for `merge_fn`. |
| |
| Returns: |
| The return value of `merge_fn`, except for `PerReplica` values which are |
| unpacked. |
| """ |
| require_replica_context(self) |
| if kwargs is None: |
| kwargs = {} |
| return self._merge_call(merge_fn, args, kwargs) |
| |
| def _merge_call(self, merge_fn, args, kwargs): |
| """Default implementation for single replica.""" |
| _push_per_thread_mode( # thread-local, so not needed with multiple threads |
| distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access |
| try: |
| return merge_fn(self._strategy, *args, **kwargs) |
| finally: |
| _pop_per_thread_mode() |
| |
| @property |
| def num_replicas_in_sync(self): |
| """Returns number of replicas over which gradients are aggregated.""" |
| return self._strategy.num_replicas_in_sync |
| |
| @property |
| def replica_id_in_sync_group(self): |
| """Returns the id of the replica being defined. |
| |
| This identifies the replica that is part of a sync group. Currently we |
| assume that all sync groups contain the same number of replicas. The value |
| of the replica id can range from 0 to `num_replica_in_sync` - 1. |
| """ |
| require_replica_context(self) |
| return self._replica_id_in_sync_group |
| |
| @property |
| def strategy(self): |
| """The current `tf.distribute.Strategy` object.""" |
| return self._strategy |
| |
| @property |
| def devices(self): |
| """The devices this replica is to be executed on, as a tuple of strings.""" |
| require_replica_context(self) |
| return (device_util.current(),) |
| |
| def all_reduce(self, reduce_op, value): |
| """All-reduces the given `value Tensor` nest across replicas. |
| |
| If `all_reduce` is called in any replica, it must be called in all replicas. |
| The nested structure and `Tensor` shapes must be identical in all replicas. |
| |
| IMPORTANT: The ordering of communications must be identical in all replicas. |
| |
| Example with two replicas: |
| Replica 0 `value`: {'a': 1, 'b': [40, 1]} |
| Replica 1 `value`: {'a': 3, 'b': [ 2, 98]} |
| |
| If `reduce_op` == `SUM`: |
| Result (on all replicas): {'a': 4, 'b': [42, 99]} |
| |
| If `reduce_op` == `MEAN`: |
| Result (on all replicas): {'a': 2, 'b': [21, 49.5]} |
| |
| Args: |
| reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. |
| value: The nested structure of `Tensor`s to all-reduce. The structure must |
| be compatible with `tf.nest`. |
| |
| Returns: |
| A `Tensor` nest with the reduced `value`s from each replica. |
| """ |
| def batch_all_reduce(strategy, *value_flat): |
| return strategy.extended.batch_reduce_to( |
| reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat]) |
| |
| if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: |
| # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. |
| @custom_gradient.custom_gradient |
| def grad_wrapper(*xs): |
| ys = self.merge_call(batch_all_reduce, args=xs) |
| # The gradient of an all-sum is itself an all-sum (all-mean, likewise). |
| return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) |
| return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) |
| else: |
| # TODO(cjfj): Implement gradients for other reductions. |
| reduced = nest.pack_sequence_as( |
| value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) |
| return nest.map_structure(array_ops.prevent_gradient, reduced) |
| |
| # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient |
| # all-reduce. It would return a function returning the result of reducing `t` |
| # across all replicas. The caller would wait to call this function until they |
| # needed the reduce result, allowing an efficient implementation: |
| # * With eager execution, the reduction could be performed asynchronously |
| # in the background, not blocking until the result was needed. |
| # * When constructing a graph, it could batch up all reduction requests up |
| # to that point that the first result is needed. Most likely this can be |
| # implemented in terms of `merge_call()` and `batch_reduce_to()`. |
| |
| |
| def _batch_reduce_destination(x): |
| """Returns the destinations for batch all-reduce.""" |
| if isinstance(x, ops.Tensor): |
| # If this is a one device strategy. |
| return x.device |
| else: |
| return x |
| |
| |
| # ------------------------------------------------------------------------------ |
| |
| |
| _creating_default_strategy_singleton = False |
| |
| |
| class _DefaultDistributionStrategy(StrategyV1): |
| """Default `tf.distribute.Strategy` if none is explicitly selected.""" |
| |
| def __init__(self): |
| if not _creating_default_strategy_singleton: |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| super(_DefaultDistributionStrategy, self).__init__( |
| _DefaultDistributionExtended(self)) |
| |
| def __deepcopy__(self, memo): |
| del memo |
| raise RuntimeError("Should only create a single instance of " |
| "_DefaultDistributionStrategy") |
| |
| |
| class _DefaultDistributionContext(object): |
| """Context manager setting the default `tf.distribute.Strategy`.""" |
| |
| def __init__(self, strategy): |
| |
| def creator(next_creator, *args, **kwargs): |
| _require_strategy_scope_strategy(strategy) |
| return next_creator(*args, **kwargs) |
| |
| self._var_creator_scope = variable_scope.variable_creator_scope(creator) |
| self._strategy = strategy |
| self._nested_count = 0 |
| |
| def __enter__(self): |
| # Allow this scope to be entered if this strategy is already in scope. |
| if distribution_strategy_context.has_strategy(): |
| raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") |
| if self._nested_count == 0: |
| self._var_creator_scope.__enter__() |
| self._nested_count += 1 |
| return self._strategy |
| |
| def __exit__(self, exception_type, exception_value, traceback): |
| self._nested_count -= 1 |
| if self._nested_count == 0: |
| try: |
| self._var_creator_scope.__exit__( |
| exception_type, exception_value, traceback) |
| except RuntimeError as e: |
| six.raise_from( |
| RuntimeError("Variable creator scope nesting error: move call to " |
| "tf.distribute.set_strategy() out of `with` scope."), |
| e) |
| |
| |
| class _DefaultDistributionExtended(StrategyExtendedV1): |
| """Implementation of _DefaultDistributionStrategy.""" |
| |
| def __init__(self, container_strategy): |
| super(_DefaultDistributionExtended, self).__init__(container_strategy) |
| self._retrace_functions_for_each_device = False |
| |
| def _scope(self, strategy): |
| """Context manager setting a variable creator and `self` as current.""" |
| return _DefaultDistributionContext(strategy) |
| |
| def colocate_vars_with(self, colocate_with_variable): |
| """Does not require `self.scope`.""" |
| _require_strategy_scope_extended(self) |
| return ops.colocate_with(colocate_with_variable) |
| |
| def variable_created_in_scope(self, v): |
| return v._distribute_strategy is None # pylint: disable=protected-access |
| |
| def _experimental_distribute_dataset(self, dataset): |
| return dataset |
| |
| def _experimental_distribute_datasets_from_function(self, dataset_fn): |
| return dataset_fn(InputContext()) |
| |
| def _make_dataset_iterator(self, dataset): |
| return _DefaultDistributionExtended.DefaultInputIterator(dataset) |
| |
| def _make_input_fn_iterator(self, |
| input_fn, |
| replication_mode=InputReplicationMode.PER_WORKER): |
| dataset = input_fn(InputContext()) |
| return _DefaultDistributionExtended.DefaultInputIterator(dataset) |
| |
| def _experimental_make_numpy_dataset(self, numpy_input, session): |
| numpy_flat = nest.flatten(numpy_input) |
| vars_flat = tuple( |
| variable_scope.variable(array_ops.zeros(i.shape, i.dtype), |
| trainable=False, use_resource=True) |
| for i in numpy_flat |
| ) |
| for v, i in zip(vars_flat, numpy_flat): |
| numpy_dataset.init_var_from_numpy(v, i, session) |
| vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) |
| return dataset_ops.Dataset.from_tensor_slices(vars_nested) |
| |
| def _broadcast_to(self, tensor, destinations): |
| if destinations is None: |
| return tensor |
| else: |
| raise NotImplementedError("TODO") |
| |
| def _call_for_each_replica(self, fn, args, kwargs): |
| with ReplicaContext( |
| self._container_strategy(), |
| replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)): |
| return fn(*args, **kwargs) |
| |
| def _reduce_to(self, reduce_op, value, destinations): |
| # TODO(josh11b): Use destinations? |
| del reduce_op, destinations |
| return value |
| |
| def _update(self, var, fn, args, kwargs, group): |
| # The implementations of _update() and _update_non_slot() are identical |
| # except _update() passes `var` as the first argument to `fn()`. |
| return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) |
| |
| def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): |
| # TODO(josh11b): Figure out what we should be passing to UpdateContext() |
| # once that value is used for something. |
| with UpdateContext(colocate_with): |
| result = fn(*args, **kwargs) |
| if should_group: |
| return result |
| else: |
| return nest.map_structure(self._local_results, result) |
| |
| def read_var(self, replica_local_var): |
| return array_ops.identity(replica_local_var) |
| |
| def _local_results(self, distributed_value): |
| return (distributed_value,) |
| |
| def value_container(self, value): |
| return value |
| |
| @property |
| def _num_replicas_in_sync(self): |
| return 1 |
| |
| @property |
| def worker_devices(self): |
| raise RuntimeError("worker_devices() method unsupported by default " |
| "tf.distribute.Strategy.") |
| |
| @property |
| def parameter_devices(self): |
| raise RuntimeError("parameter_devices() method unsupported by default " |
| "tf.distribute.Strategy.") |
| |
| def non_slot_devices(self, var_list): |
| return min(var_list, key=lambda x: x.name) |
| |
| def _in_multi_worker_mode(self): |
| """Whether this strategy indicates working in multi-worker settings.""" |
| # Default strategy doesn't indicate multi-worker training. |
| return False |
| |
| # TODO(priyag): This should inherit from `InputIterator`, once dependency |
| # issues have been resolved. |
| class DefaultInputIterator(object): |
| """Default implementation of `InputIterator` for default strategy.""" |
| |
| def __init__(self, dataset): |
| self._dataset = dataset |
| if eager_context.executing_eagerly(): |
| self._iterator = dataset_ops.make_one_shot_iterator(dataset) |
| else: |
| self._iterator = dataset_ops.make_initializable_iterator(dataset) |
| |
| def get_next(self): |
| return self._iterator.get_next() |
| |
| def initialize(self): |
| if eager_context.executing_eagerly(): |
| self._iterator = self._dataset.make_one_shot_iterator() |
| return [] |
| else: |
| return [self._iterator.initializer] |
| |
| # TODO(priyag): Delete this once all strategies use global batch size. |
| @property |
| def _global_batch_size(self): |
| """Global and per-replica batching are equivalent for this strategy.""" |
| return True |
| |
| |
| # ------------------------------------------------------------------------------ |
| # We haven't yet implemented deserialization for DistributedVariables. |
| # So here we catch any attempts to deserialize variables |
| # when using distribution strategies. |
| # pylint: disable=protected-access |
| _original_from_proto = resource_variable_ops._from_proto_fn |
| |
| |
| def _from_proto_fn(v, import_scope=None): |
| if distribution_strategy_context.has_strategy(): |
| raise NotImplementedError( |
| "Deserialization of variables is not yet supported when using a " |
| "tf.distribute.Strategy.") |
| else: |
| return _original_from_proto(v, import_scope=import_scope) |
| |
| resource_variable_ops._from_proto_fn = _from_proto_fn |
| # pylint: enable=protected-access |
| |
| |
| #------------------------------------------------------------------------------- |
| # Shorthand for some methods from distribution_strategy_context. |
| _push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access |
| _get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access |
| _pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access |
| _get_default_replica_mode = ( |
| distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access |
| |
| |
| # ------------------------------------------------------------------------------ |
| # Metrics to track which distribution strategy is being called |
| distribution_strategy_gauge = monitoring.StringGauge( |
| "/tensorflow/api/distribution_strategy", |
| "Gauge to track the type of distribution strategy used.", "TFVersion") |
| distribution_strategy_replica_gauge = monitoring.IntGauge( |
| "/tensorflow/api/distribution_strategy/replica", |
| "Gauge to track the number of replica each distribution strategy used.", |
| "CountType") |