| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Class implementing utilities used by tf.distribute.Strategy.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| from collections import abc |
| |
| from tensorflow.python.distribute import tpu_values as tpu_values_lib |
| from tensorflow.python.distribute import values as values_lib |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import tape |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import variable_scope as vs |
| from tensorflow.python.util import nest |
| |
| |
def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Merges one nest per replica into a single nest of grouped leaves.

  `values` holds one entry per replica. When the entries are lists, tuples
  (including namedtuples) or mappings, they are walked in lockstep and every
  position/key is regrouped recursively, so the result has the nesting of
  `values[0]`. Leaves are either returned as-is, replaced by the distributed
  container they belong to, or passed to `wrap_class`.

  Args:
    values: Sequence with one (possibly nested) value per replica. The first
      entry determines the structure, so the sequence must not be empty.
    wrap_class: Callable applied to the sequence of per-replica leaves when
      they need to be wrapped.
    always_wrap: If true, leaves that are the very same object on every
      replica are still handed to `wrap_class`. `DistributedVariable`s and
      objects exposing `_distributed_container` are not affected by this.

  Returns:
    The regrouped structure.
  """
  first = values[0]
  others = values[1:]

  def _regroup_at(key):
    # Gather what every replica stores under `key` (a list/tuple index or a
    # mapping key) and regroup that slice recursively.
    return regroup(tuple(v[key] for v in values), wrap_class, always_wrap)

  if isinstance(first, list):
    for other in others:
      assert isinstance(other, list)
      assert len(other) == len(first), (
          "len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
          (len(other), len(first), other, first))
    return [_regroup_at(index) for index in range(len(first))]

  if isinstance(first, tuple):
    for other in others:
      assert isinstance(other, tuple)
      assert len(other) == len(first)
    regrouped_items = tuple(_regroup_at(index) for index in range(len(first)))
    if not hasattr(first, "_fields"):
      return regrouped_items
    # A tuple with `_fields` is a namedtuple: rebuild an instance of the same
    # type from the regrouped items.
    assert hasattr(first, "_make")
    return first._make(regrouped_items)

  if isinstance(first, abc.Mapping):
    first_keys = first.keys()
    for other in others:
      assert isinstance(other, abc.Mapping), (
          "v[0]: %r v[i]: %r" % (first, other))
      assert set(other.keys()) == set(first_keys), (
          "v[0].keys: %s v[i].keys: %s" %
          (set(first_keys), set(other.keys())))
    # Construct with the concrete type so dict subclasses are preserved.
    return type(first)({key: _regroup_at(key) for key in first_keys})

  # From here on `first` is a leaf. Check whether every replica holds the
  # very same object.
  identical = all(other is first for other in others)

  # A DistributedVariable shared by all replicas is returned directly. It is
  # tested for explicitly because it can look like it has a
  # `_distributed_container` member, given that its components do.
  if identical and isinstance(first, values_lib.DistributedVariable):
    return first

  # Components of a distributed variable expose `_distributed_container`;
  # they are resolved to their container below. This can coincide with
  # `identical` when there is only a single device.
  has_container = hasattr(first, "_distributed_container")

  # Any other object shared by all replicas is returned unwrapped, unless the
  # caller asked for wrapping unconditionally.
  if identical and not always_wrap and not has_container:
    return first

  if not has_container:
    return wrap_class(values)

  # Every replica holds a parallel component of one distributed variable:
  # return that containing variable after checking that all components agree
  # on the container.
  # pylint: disable=protected-access
  assert not isinstance(first, values_lib.MirroredVariable), (
      "ids = %s, values = %s" % ([id(v) for v in values], values))
  container = first._distributed_container()
  assert container is not None
  for other in others:
    assert container is other._distributed_container()
  # pylint: enable=protected-access
  return container
| |
| |
def select_replica(replica_id, structured):
  """Narrows a nest of regular and per-replica values to a single replica.

  Args:
    replica_id: Index into the `values` attribute of each distributed value.
    structured: Nest whose leaves may be `DistributedValues`,
      `DistributedVariable`s or anything else.

  Returns:
    A nest with the same structure in which every `DistributedValues` leaf
    that is not a `DistributedVariable` is replaced by its
    `values[replica_id]`; all other leaves are passed through unchanged.
  """

  def _pick(leaf):
    # Only plain distributed values are sliced. Distributed variables are
    # passed through because they can be handled directly in the replica
    # context.
    sliceable = (
        isinstance(leaf, values_lib.DistributedValues) and
        not isinstance(leaf, values_lib.DistributedVariable))
    return leaf.values[replica_id] if sliceable else leaf

  return nest.map_structure(_pick, structured)
| |
| |
def select_replica_mirrored(replica_id, structured):
  """Narrows a nest of regular and mirrored values to a single replica.

  Runs `assert_mirrored` on `structured` first, then delegates to
  `select_replica`.

  Args:
    replica_id: Forwarded to `select_replica`.
    structured: Nest of regular and mirrored values.

  Returns:
    Whatever `select_replica(replica_id, structured)` returns.
  """
  # Validation happens before any selection is attempted.
  assert_mirrored(structured)
  selected = select_replica(replica_id, structured)
  return selected
| |
| |
def assert_mirrored(structured):
  """Checks that `structured` contains only mirrored or regular values.

  Args:
    structured: Nest to inspect.

  Raises:
    TypeError: If any leaf is a `DistributedValues` instance for which
      `is_mirrored` returns a false value.
  """

  def _check(leaf):
    # Regular (non-distributed) leaves are always acceptable.
    if not isinstance(leaf, values_lib.DistributedValues):
      return
    if is_mirrored(leaf):
      return
    raise TypeError(
        "Expected value to be mirrored across replicas: %s in %s." %
        (leaf, structured))

  nest.map_structure(_check, structured)
| |
| |
def update_regroup(extended, updates, group):
  """Regroups update results, optionally tying all replicas together.

  Args:
    extended: Object whose `_local_results` method is mapped over the
      regrouped structure when `group` is false.
    updates: Per-replica update results, in the form accepted by `regroup`.
    group: If false, results are regrouped as `values_lib.Mirrored` and
      converted with `extended._local_results`. If true, the per-replica
      leaves are combined with a group op so that all updates execute.

  Returns:
    The regrouped structure described above.
  """

  def _with_dependency(value, dependency):
    # Copy of `value`, placed on the same device, that can only be evaluated
    # after `dependency` has run.
    with ops.device(value.device), ops.control_dependencies([dependency]):
      return array_ops.identity(value)

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Group every replica's update. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    run_all = control_flow_ops.group(values)

    # All entries are expected to have the same type, since every replica
    # performs the same computation. If they are not all tensors, the group
    # op on its own is the result.
    only_tensors = all(tensor_util.is_tf_type(v) for v in values)
    if not only_tensors:
      return run_all

    # Otherwise hand back tensors carrying the same values as `values` but
    # depending on the group op.
    return values_lib.Mirrored([_with_dependency(v, run_all) for v in values])

  if group:
    return regroup(updates, _make_grouped_mirrored)

  mirrored = regroup(updates, values_lib.Mirrored)
  return nest.map_structure(extended._local_results, mirrored)  # pylint: disable=protected-access
| |
| |
def value_container(val):
  """Returns the container that the per-replica value `val` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created
      in `scope()`.

  Returns:
    The object returned by `val._distributed_container()` when `val` has that
    attribute, is not itself a `DistributedVariable`, and the call does not
    return None. In every other case (no container, or the container has been
    destroyed) `val` itself is returned.
  """
  if not hasattr(val, "_distributed_container"):
    return val
  if isinstance(val, values_lib.DistributedVariable):
    # DistributedVariable defines `_distributed_container` too, but its
    # container is deliberately not returned.
    return val
  owner = val._distributed_container()  # pylint: disable=protected-access
  return val if owner is None else owner
| |
| |
def is_distributed_variable(v):
  """Returns whether `v` is a `values_lib.DistributedVariable` instance.

  Instances of subclasses count as well, since the check is `isinstance`.
  """
  distributed_variable_cls = values_lib.DistributedVariable
  return isinstance(v, distributed_variable_cls)
| |
| |
def _validate_colocate_extended(v, extended):
  """Checks that `v` was created under the strategy owning `extended`.

  Args:
    v: Object with a `_distribute_strategy` attribute.
    extended: The object that `v._distribute_strategy.extended` must be
      identical to.

  Raises:
    ValueError: If `v._distribute_strategy.extended` is not `extended`.
  """
  owning_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if owning_strategy.extended is extended:
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
      (v, owning_strategy))
| |
| |
def validate_colocate_distributed_variable(v, extended):
  """Validates a `colocate_vars_with` argument that must be distributed.

  Args:
    v: Candidate variable.
    extended: Forwarded to `_validate_colocate_extended`.

  Raises:
    ValueError: If `v` is not a `values_lib.DistributedVariable`, or if
      `_validate_colocate_extended` rejects it.
  """
  if isinstance(v, values_lib.DistributedVariable):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
| |
| |
def validate_colocate(v, extended):
  """Validates a `colocate_vars_with` argument by duck typing.

  Args:
    v: Candidate variable; it must expose `_distribute_strategy`.
    extended: Forwarded to `_validate_colocate_extended`.

  Raises:
    ValueError: If `v` has no `_distribute_strategy` attribute, or if
      `_validate_colocate_extended` rejects it.
  """
  if hasattr(v, "_distribute_strategy"):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
| |
| |
# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validates the requested synchronization mode and resolves `AUTO`.

  Args:
    kwargs: Variable-creation keyword arguments. `"synchronization"` is read,
      defaulting to `vs.VariableSynchronization.AUTO`. `"name"` is used only
      to build error messages and may be absent.

  Returns:
    `vs.VariableSynchronization.ON_WRITE` when the requested mode is `AUTO`,
    otherwise the requested mode unchanged.

  Raises:
    ValueError: If the mode is `NONE`, or is not one of `ON_READ`,
      `ON_WRITE` or `AUTO`.
  """
  sync_modes = vs.VariableSynchronization
  synchronization = kwargs.get("synchronization", sync_modes.AUTO)
  # Looked up with `.get` so that a missing "name" cannot turn the intended
  # ValueError below into a KeyError.
  name = kwargs.get("name")

  if synchronization == sync_modes.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(name))
  if synchronization not in (sync_modes.ON_READ,
                             sync_modes.ON_WRITE,
                             sync_modes.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, name))
  if synchronization == sync_modes.AUTO:
    return sync_modes.ON_WRITE
  return synchronization
| |
| |
def _validate_aggregation(kwargs):
  """Validates the requested aggregation mode.

  Args:
    kwargs: Variable-creation keyword arguments. `"aggregation"` is read,
      defaulting to `vs.VariableAggregation.NONE`. `"name"` is used only to
      build the error message and may be absent.

  Returns:
    The requested aggregation mode.

  Raises:
    ValueError: If the mode is not one of `NONE`, `SUM`, `MEAN` or
      `ONLY_FIRST_REPLICA`.
  """
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    # `.get` keeps a missing "name" from masking this error with a KeyError.
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs.get("name")))
  return aggregation
| |
| |
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation.

  Args:
    strategy: Passed as the first constructor argument of the resulting
      variable class. `strategy.extended._use_var_policy` (False when absent)
      selects policy-based construction.
    real_mirrored_creator: Callable invoked with the processed `kwargs`; it
      returns the list of component variables.
    class_mapping: Mapping from synchronization mode to variable class. With
      policy-based construction the `"VariableClass"` entry is used instead.
    policy_mapping: Mapping from synchronization mode to policy class; only
      consulted for policy-based construction.
    **kwargs: Variable-creation arguments. `collections`, `synchronization`,
      `aggregation`, `caching_device` and `trainable` are interpreted here;
      the rest is forwarded to `real_mirrored_creator`.

  Returns:
    The distributed variable wrapping the created components.

  Raises:
    ValueError: If the synchronization or aggregation mode is invalid.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  elif isinstance(var_collections, (list, tuple, set)):
    # Work on a private list: TRAINABLE_VARIABLES may be appended below, which
    # must not leak into the caller's list and must not fail on a tuple/set.
    var_collections = list(var_collections)
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    # `global_variables_initializer` wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      # pylint:disable=protected-access
      if v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
      # pylint:enable=protected-access
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        # Drop the first entry that is this exact component; stopping right
        # after the deletion keeps the enumerate() over `l` safe.
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
| |
| |
# Utility functions
def is_mirrored(val):
  """Returns True if the value is Mirrored or the variable is kept in sync.

  A `DistributedVariable` with a truthy `_policy` is judged by that policy's
  `_is_mirrored()`. Anything else is mirrored exactly when it is an instance
  of `values_lib.Mirrored`.
  """
  policy = None
  if isinstance(val, values_lib.DistributedVariable):
    policy = val._policy  # pylint: disable=protected-access
  if policy:
    return policy._is_mirrored()  # pylint: disable=protected-access
  return isinstance(val, values_lib.Mirrored)
| |
| |
def is_sync_on_read(val):
  """Returns True if `val` is not mirrored.

  A `DistributedVariable` with a truthy `_policy` is judged by the negation of
  that policy's `_is_mirrored()`. Anything else counts as sync-on-read exactly
  when it is not an instance of `values_lib.Mirrored`.
  """
  policy = None
  if isinstance(val, values_lib.DistributedVariable):
    policy = val._policy  # pylint: disable=protected-access
  if policy:
    return not policy._is_mirrored()  # pylint: disable=protected-access
  return not isinstance(val, values_lib.Mirrored)
| |
# The following mapping indicates the policy that you must use for a given
# variable `synchronization` and `aggregation` pair.
# OnWritePolicy is used for:
# (synchronization=AUTO, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# There is no AUTO key: `_validate_synchronization` converts AUTO to ON_WRITE
# before `create_mirrored_variable` performs its lookup.
# NOTE(review): presumably passed as `policy_mapping` to
# `create_mirrored_variable` -- confirm against callers.
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

# Variable classes keyed by synchronization mode, plus a "VariableClass" entry.
# `create_mirrored_variable` reads "VariableClass" from its `class_mapping`
# when policy-based construction is enabled, and the per-mode entry otherwise.
# NOTE(review): presumably passed as `class_mapping` -- confirm against
# callers.
VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

# Same layout as VARIABLE_POLICY_MAPPING, using the policy classes from
# `tpu_values_lib`.
TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

# Same layout as VARIABLE_CLASS_MAPPING, using the variable classes from
# `tpu_values_lib`.
TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}