Add support for variable policy to be used by MirroredStrategy and TPUStrategy. Refactor existing values test and add an option to test variable policy.
PiperOrigin-RevId: 323304491
Change-Id: I5c00791bc62a930274c254b33f4a47d671d0b7bf
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 185b456..356fb3a 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -302,7 +302,6 @@
":distribute_lib",
":reduce_util",
":shared_variable_creator",
- ":tpu_values",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:config",
@@ -1226,46 +1225,6 @@
)
distribute_py_test(
- name = "vars_test",
- size = "medium",
- srcs = ["vars_test.py"],
- main = "vars_test.py",
- shard_count = 5,
- tags = [
- "multi_and_single_gpu",
- "no_rocm",
- ],
- tpu_tags = [
- "no_oss", # b/150954621 Target too big to run serially reliably.
- ],
- deps = [
- ":combinations",
- ":distribute_lib",
- ":strategy_combinations",
- ":tpu_strategy",
- ":tpu_values",
- ":values",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:checkpoint_management",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:indexed_slices",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/eager:def_function",
- "//tensorflow/python/eager:test",
- "//tensorflow/python/tpu:tpu_lib",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-distribute_py_test(
name = "ps_values_test",
size = "medium",
srcs = ["ps_values_test.py"],
diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py
index a86c751..ad8bb87 100644
--- a/tensorflow/python/distribute/combinations.py
+++ b/tensorflow/python/distribute/combinations.py
@@ -58,17 +58,11 @@
"""
def modified_arguments(self, kwargs, requested_parameters):
- # Get the parameter that indicates if we need to set the `_use_policy` flag
- # on the strategy object. This is a temporary flag for testing the variable
- # policy rollout.
- use_var_policy = kwargs.get("use_var_policy", None)
+ del requested_parameters
distribution_arguments = {}
for k, v in kwargs.items():
if isinstance(v, NamedDistribution):
- strategy = v.strategy
- if use_var_policy:
- strategy.extended._use_var_policy = use_var_policy
- distribution_arguments[k] = strategy
+ distribution_arguments[k] = v.strategy
return distribution_arguments
diff --git a/tensorflow/python/distribute/distribute_utils.py b/tensorflow/python/distribute/distribute_utils.py
index 916ebaf..89848b9 100644
--- a/tensorflow/python/distribute/distribute_utils.py
+++ b/tensorflow/python/distribute/distribute_utils.py
@@ -18,7 +18,6 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
@@ -146,7 +145,7 @@
def _get_mirrored(x):
if isinstance(x, values_lib.DistributedValues):
- if not is_mirrored(x):
+ if not isinstance(x, values_lib.Mirrored):
raise TypeError(
"Expected value to be mirrored across replicas: %s in %s." %
(x, structured))
@@ -246,25 +245,34 @@
# Variable creation function for sync strategies.
-def _get_and_validate_synchronization(kwargs):
- """Validate that given synchronization value is valid."""
+def create_mirrored_variable( # pylint: disable=missing-docstring
+ strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
+ # Figure out what collections this variable should be added to.
+ # We'll add the MirroredVariable to those collections instead.
+ var_collections = kwargs.pop("collections", None)
+ if var_collections is None:
+ var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
synchronization = kwargs.get("synchronization",
- vs.VariableSynchronization.AUTO)
+ vs.VariableSynchronization.ON_WRITE)
+
if synchronization == vs.VariableSynchronization.NONE:
raise ValueError(
- "`NONE` variable synchronization mode is not supported with "
- "tf.distribute strategy. Please change the `synchronization` for "
+ "`NONE` variable synchronization mode is not supported with `Mirrored` "
+ "distribution strategy. Please change the `synchronization` for "
"variable: " + str(kwargs["name"]))
- if synchronization not in (vs.VariableSynchronization.ON_READ,
- vs.VariableSynchronization.ON_WRITE,
- vs.VariableSynchronization.AUTO):
+ elif synchronization == vs.VariableSynchronization.ON_READ:
+ is_sync_on_read = True
+ elif synchronization in (vs.VariableSynchronization.ON_WRITE,
+ vs.VariableSynchronization.AUTO):
+ # `AUTO` synchronization defaults to `ON_WRITE`.
+ is_sync_on_read = False
+ else:
raise ValueError(
"Invalid variable synchronization mode: %s for variable: %s" %
(synchronization, kwargs["name"]))
- return synchronization
-
-def _validate_aggregation(kwargs):
aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
if aggregation not in (vs.VariableAggregation.NONE,
@@ -273,33 +281,6 @@
vs.VariableAggregation.ONLY_FIRST_REPLICA):
raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
(aggregation, kwargs["name"]))
- return aggregation
-
-
-def _get_variable_policy_class(synchronization, aggregation, policy_mapping):
- if synchronization == vs.VariableSynchronization.AUTO:
- if aggregation == vs.VariableAggregation.NONE:
- # Use AutoPolicy.
- return policy_mapping.get(synchronization)
- else:
- # Revert to OnWritePolicy
- return policy_mapping.get(vs.VariableSynchronization.ON_WRITE)
- return policy_mapping.get(synchronization)
-
-
-def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
- policy_mapping, **kwargs):
- """Create distributed variables with given synchronization and aggregation."""
- # Figure out what collections this variable should be added to.
- # We'll add the MirroredVariable to those collections instead.
- var_collections = kwargs.pop("collections", None)
- if var_collections is None:
- var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
- kwargs["collections"] = []
-
- synchronization = _get_and_validate_synchronization(kwargs)
- aggregation = _validate_aggregation(kwargs)
- use_var_policy = getattr(strategy.extended, "_use_var_policy", False)
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
@@ -309,15 +290,8 @@
# here.
with tape.stop_recording():
value_list = real_mirrored_creator(**kwargs)
- if use_var_policy:
- var_policy_cls = _get_variable_policy_class(synchronization, aggregation,
- policy_mapping)
- var_policy = var_policy_cls(aggregation=aggregation)
- var_cls = class_mapping.get("VariableClass")
- result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
- else:
- var_cls = class_mapping.get(synchronization)
- result = var_cls(strategy, value_list, aggregation)
+ var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
+ result = var_cls(strategy, value_list, aggregation)
# Install the created DistributedVariable as _distributed_container property
# of the underlying variables, to make it easy to map back to the container.
for v in result.values:
@@ -350,55 +324,3 @@
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
return result
-
-
-# Utility functions
-# Return True if the Value is Mirrored or the Variable is replicated and kept in
-# sync.
-def is_mirrored(val):
- if isinstance(val, values_lib.DistributedVariable):
- if val._policy: # pylint: disable=protected-access
- return val._policy._is_mirrored() # pylint: disable=protected-access
- return isinstance(val, values_lib.Mirrored)
-
-
-def is_sync_on_read(val):
- if isinstance(val, values_lib.DistributedVariable):
- if val._policy: # pylint: disable=protected-access
- return not val._policy._is_mirrored() # pylint: disable=protected-access
- return not isinstance(val, values_lib.Mirrored)
-
-# The following mapping indicates the policy that you must use for a given
-# variable `synchronization` and `aggregation` pair.
-# AutoPolicy is used for:
-# (synchronization=Auto, aggregation=None)
-# OnWritePolicy is used for:
-# (synchronization=Auto, aggregation=SUM,MEAN,ONLY_FIRST_REPLICA)
-# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
-# OnReadPolicy is used for:
-# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
-VARIABLE_POLICY_MAPPING = {
- vs.VariableSynchronization.AUTO: values_lib.AutoPolicy,
- vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
- vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
-}
-
-VARIABLE_CLASS_MAPPING = {
- "VariableClass": values_lib.DistributedVariable,
- vs.VariableSynchronization.AUTO: values_lib.MirroredVariable,
- vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
- vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
-}
-
-TPU_VARIABLE_POLICY_MAPPING = {
- vs.VariableSynchronization.AUTO: tpu_values_lib.TPUAutoPolicy,
- vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
- vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
-}
-
-TPU_VARIABLE_CLASS_MAPPING = {
- "VariableClass": tpu_values_lib.TPUDistributedVariable,
- vs.VariableSynchronization.AUTO: tpu_values_lib.TPUMirroredVariable,
- vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
- vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
-}
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 5323f61..b424f79 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -319,9 +319,6 @@
if ops.executing_eagerly_outside_functions():
self.experimental_enable_get_next_as_optional = True
- # Flag to turn on VariablePolicy.
- self._use_var_policy = False
-
def _initialize_strategy(self, devices):
# The _initialize_strategy method is intended to be used by distribute
# coordinator as well.
@@ -465,8 +462,7 @@
return distribute_utils.create_mirrored_variable(
self._container_strategy(), _real_mirrored_creator,
- distribute_utils.VARIABLE_CLASS_MAPPING,
- distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)
+ values.MirroredVariable, values.SyncOnReadVariable, **kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils.validate_colocate_distributed_variable(
@@ -632,10 +628,10 @@
return self._cross_device_ops or self._inferred_cross_device_ops
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
- if (distribute_utils.is_mirrored(value) and
+ if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN):
return value
- assert not distribute_utils.is_mirrored(value)
+ assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.DistributedValues):
# This function handles reducing values that are not PerReplica or
# Mirrored values. For example, the same value could be present on all
@@ -690,12 +686,10 @@
def read_var(self, replica_local_var):
"""Read the aggregate value of a replica-local variable."""
- # pylint: disable=protected-access
- if values._is_sync_on_read(replica_local_var):
- return replica_local_var._get_cross_replica()
- assert values._is_mirrored(replica_local_var)
- return array_ops.identity(replica_local_var._get())
- # pylint: enable=protected-access
+ if isinstance(replica_local_var, values.SyncOnReadVariable):
+ return replica_local_var._get_cross_replica() # pylint: disable=protected-access
+ assert isinstance(replica_local_var, values.Mirrored)
+ return array_ops.identity(replica_local_var._get()) # pylint: disable=protected-access
def _local_results(self, val):
if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py
index 03d697f..8e7d674 100644
--- a/tensorflow/python/distribute/mirrored_variable_test.py
+++ b/tensorflow/python/distribute/mirrored_variable_test.py
@@ -379,7 +379,8 @@
with distribution.scope():
with self.assertRaisesRegex(
ValueError, "`NONE` variable synchronization mode is not "
- "supported with "):
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
variable_scope.get_variable(
"v", [1],
synchronization=variable_scope.VariableSynchronization.NONE)
@@ -388,7 +389,8 @@
with distribution.scope():
with self.assertRaisesRegex(
ValueError, "`NONE` variable synchronization mode is not "
- "supported with "):
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
variable_scope.variable(
1.0,
name="v",
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index bad6e6a..8e5ef06 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -544,9 +544,6 @@
context.async_wait()
atexit.register(async_wait)
- # Flag to turn on VariablePolicy
- self._use_var_policy = False
-
def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils. validate_colocate(colocate_with_variable, self)
@@ -873,8 +870,8 @@
return distribute_utils.create_mirrored_variable(
self._container_strategy(), _real_mirrored_creator,
- distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
- distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)
+ tpu_values.TPUMirroredVariable, tpu_values.TPUSyncOnReadVariable,
+ **kwargs)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
if (isinstance(value, values.DistributedValues) or
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index ce6d2e7..3388553 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -197,58 +197,10 @@
return None
-class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
- """DistributedVariable subclass for TPUStrategy."""
-
- def _is_mirrored(self):
- self._policy._is_mirrored() # pylint: disable=protected-access
-
- def assign_sub(self, value, use_locking=False, name=None, read_value=True):
- return self._policy.assign_sub(
- self, value, use_locking=use_locking, name=name, read_value=read_value)
-
- def assign_add(self, value, use_locking=False, name=None, read_value=True):
- return self._policy.assign_add(
- self, value, use_locking=use_locking, name=name, read_value=read_value)
-
- def assign(self, value, use_locking=False, name=None, read_value=True):
- return self._policy.assign(
- self, value, use_locking=use_locking, name=name, read_value=read_value)
-
- def scatter_sub(self, sparse_delta, use_locking=False, name=None):
- return self._policy.scatter_sub(
- self, sparse_delta, use_locking=use_locking, name=name)
-
- def scatter_add(self, sparse_delta, use_locking=False, name=None):
- return self._policy.scatter_add(
- self, sparse_delta, use_locking=use_locking, name=name)
-
- def scatter_mul(self, sparse_delta, use_locking=False, name=None):
- return self._policy.scatter_mul(
- self, sparse_delta, use_locking=use_locking, name=name)
-
- def scatter_div(self, sparse_delta, use_locking=False, name=None):
- return self._policy.scatter_div(
- self, sparse_delta, use_locking=use_locking, name=name)
-
- def scatter_min(self, sparse_delta, use_locking=False, name=None):
- return self._policy.scatter_min(
- self, sparse_delta, use_locking=use_locking, name=name)
-
- def scatter_max(self, sparse_delta, use_locking=False, name=None):
- return self._policy.scatter_max(
- self, sparse_delta, use_locking=use_locking, name=name)
-
- def scatter_update(self, sparse_delta, use_locking=False, name=None):
- return self._policy.scatter_update(
- self, sparse_delta, use_locking=use_locking, name=name)
-
-
class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
"""Holds a map from replica to TPU variables whose values are kept in sync."""
- def assign_sub(self, value, use_locking=False, name=None,
- read_value=True):
+ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
if (enclosing_tpu_context() and
self.aggregation == variable_scope.VariableAggregation.NONE):
return _make_raw_assign_fn(
@@ -258,11 +210,17 @@
use_locking=use_locking,
name=name,
read_value=read_value)
- return assign_sub(self, value, use_locking=use_locking, name=name,
- read_value=read_value)
- def assign_add(self, value, use_locking=False, name=None,
- read_value=True):
+ assign_sub_fn = _make_raw_assign_fn(
+ gen_resource_variable_ops.assign_sub_variable_op)
+ return self._update(
+ update_fn=assign_sub_fn,
+ value=value,
+ use_locking=use_locking,
+ name=name,
+ read_value=read_value)
+
+ def assign_add(self, value, use_locking=False, name=None, read_value=True):
if (enclosing_tpu_context() and
self.aggregation == variable_scope.VariableAggregation.NONE):
return _make_raw_assign_fn(
@@ -272,21 +230,34 @@
use_locking=use_locking,
name=name,
read_value=read_value)
- return assign_add(self, value, use_locking=use_locking, name=name,
- read_value=read_value)
+
+ assign_add_fn = _make_raw_assign_fn(
+ gen_resource_variable_ops.assign_add_variable_op)
+ return self._update(
+ update_fn=assign_add_fn,
+ value=value,
+ use_locking=use_locking,
+ name=name,
+ read_value=read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
if (enclosing_tpu_context() and
self.aggregation == variable_scope.VariableAggregation.NONE):
- return _make_raw_assign_fn(
- gen_resource_variable_ops.assign_variable_op)(
- self,
- value=value,
- use_locking=use_locking,
- name=name,
- read_value=read_value)
- return assign(self, value, use_locking=use_locking, name=name,
- read_value=read_value)
+ return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+ self,
+ value=value,
+ use_locking=use_locking,
+ name=name,
+ read_value=read_value)
+
+ assign_fn = _make_raw_assign_fn(
+ gen_resource_variable_ops.assign_variable_op)
+ return self._update(
+ update_fn=assign_fn,
+ value=value,
+ use_locking=use_locking,
+ name=name,
+ read_value=read_value)
def scatter_sub(self, *args, **kwargs):
raise NotImplementedError
@@ -341,220 +312,3 @@
def _is_mirrored(self):
return False
-
-
-# Common method between AutoPolicy, OnWrite and Mirrored variables.
-def assign_sub(var, value, use_locking=False, name=None, read_value=True):
- assign_sub_fn = _make_raw_assign_fn(
- gen_resource_variable_ops.assign_sub_variable_op)
- return var._update( # pylint: disable=protected-access
- update_fn=assign_sub_fn,
- value=value,
- use_locking=use_locking,
- name=name,
- read_value=read_value)
-
-
-def assign_add(var, value, use_locking=False, name=None, read_value=True):
- assign_add_fn = _make_raw_assign_fn(
- gen_resource_variable_ops.assign_add_variable_op)
- return var._update( # pylint: disable=protected-access
- update_fn=assign_add_fn,
- value=value,
- use_locking=use_locking,
- name=name,
- read_value=read_value)
-
-
-def assign(var, value, use_locking=False, name=None, read_value=True):
- assign_fn = _make_raw_assign_fn(
- gen_resource_variable_ops.assign_variable_op)
- return var._update( # pylint: disable=protected-access
- update_fn=assign_fn,
- value=value,
- use_locking=use_locking,
- name=name,
- read_value=read_value)
-
-
-class TPUAutoPolicy(values.AutoPolicy):
- """Policy defined for `tf.VariableSynchronization.AUTO` synchronization.
-
- This policy is created when `synchronization` is set to
- `tf.VariableSynchronization.AUTO` and `aggregation` is set to
- `tf.VariableAggregation.NONE` when creating a `tf.Variable` in `tf.distribute`
- scope.
- """
-
- def assign_sub(self, var, value, use_locking=False, name=None,
- read_value=True):
- if enclosing_tpu_context():
- return _make_raw_assign_fn(
- gen_resource_variable_ops.assign_sub_variable_op)(
- var,
- value=value,
- use_locking=use_locking,
- name=name,
- read_value=read_value)
- return assign_sub(var, value, use_locking=use_locking, name=name,
- read_value=read_value)
-
- def assign_add(self, var, value, use_locking=False, name=None,
- read_value=True):
- if enclosing_tpu_context():
- return _make_raw_assign_fn(
- gen_resource_variable_ops.assign_add_variable_op)(
- var,
- value=value,
- use_locking=use_locking,
- name=name,
- read_value=read_value)
- return assign_add(var, value, use_locking=use_locking, name=name,
- read_value=read_value)
-
- def assign(self, var, value, use_locking=False, name=None, read_value=True):
- if enclosing_tpu_context():
- return _make_raw_assign_fn(
- gen_resource_variable_ops.assign_variable_op)(
- var,
- value=value,
- use_locking=use_locking,
- name=name,
- read_value=read_value)
- return assign(var, value, use_locking=use_locking, name=name,
- read_value=read_value)
-
- def scatter_sub(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_add(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_max(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_min(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_mul(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_div(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_update(self, *args, **kwargs):
- raise NotImplementedError
-
- def _is_mirrored(self):
- return True
-
-
-class TPUOnWritePolicy(values.OnWritePolicy):
- """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
-
- This policy is created when the following `synchronization` and
- `aggregation` parameters are specified when creating a `tf.Variable` in
- `tf.distribute` scope:
- * `synchronization` is equal to `tf.VariableSynchronization.AUTO` and
- aggregation can be any of the following `tf.VariableAggregation` enum
- values such as `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
- * `synchronization` is equal to `tf.VariableSynchronization.ON_WRITE` and
- aggregation can be any of the following `tf.VariableAggregation` enum
- values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
- """
-
- def assign_sub(self, var, value, use_locking=False, name=None,
- read_value=True):
- return assign_sub(var, value, use_locking=use_locking, name=name,
- read_value=read_value)
-
- def assign_add(self, var, value, use_locking=False, name=None,
- read_value=True):
- return assign_add(var, value, use_locking=use_locking, name=name,
- read_value=read_value)
-
- def assign(self, var, value, use_locking=False, name=None, read_value=True):
- return assign(var, value, use_locking=use_locking, name=name,
- read_value=read_value)
-
- def scatter_sub(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_add(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_max(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_min(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_mul(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_div(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_update(self, *args, **kwargs):
- raise NotImplementedError
-
- def _is_mirrored(self):
- return True
-
-
-class TPUOnReadPolicy(values.OnReadPolicy):
- """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
-
- This policy is created when `synchronization` is set to
- `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
- values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
- `MEAN` or `ONLY_FIRST_REPLICA`when creating a `tf.Variable` in `tf.distribute`
- scope.
- """
-
- def assign_sub(self, var, *args, **kwargs):
- if enclosing_tpu_context() is None:
- return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
- else:
- return _make_raw_assign_fn(
- gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
- **kwargs)
-
- def assign_add(self, var, *args, **kwargs):
- if enclosing_tpu_context() is None:
- return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
- else:
- return _make_raw_assign_fn(
- gen_resource_variable_ops.assign_add_variable_op)(var, *args,
- **kwargs)
-
- def assign(self, var, *args, **kwargs):
- if enclosing_tpu_context() is None:
- return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
- else:
- return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
- var, *args, **kwargs)
-
- def _is_mirrored(self):
- return False
-
- def scatter_sub(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_add(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_max(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_min(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_mul(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_div(self, *args, **kwargs):
- raise NotImplementedError
-
- def scatter_update(self, *args, **kwargs):
- raise NotImplementedError
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 7dedbee..50cd8d7 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -700,49 +700,49 @@
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
if self._policy:
- return self._policy.scatter_sub(
+ self._policy.scatter_sub(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_sub(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
if self._policy:
- return self._policy.scatter_add(
+ self._policy.scatter_add(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_add(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
if self._policy:
- return self._policy.scatter_mul(
+ self._policy.scatter_mul(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_mul(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
if self._policy:
- return self._policy.scatter_div(
+ self._policy.scatter_div(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_div(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
if self._policy:
- return self._policy.scatter_min(
+ self._policy.scatter_min(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_min(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
if self._policy:
- return self._policy.scatter_max(
+ self._policy.scatter_max(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_max(
self, sparse_delta, use_locking=use_locking, name=name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
if self._policy:
- return self._policy.scatter_update(
+ self._policy.scatter_update(
self, sparse_delta, use_locking=use_locking, name=name)
return values_util.scatter_update(
self, sparse_delta, use_locking=use_locking, name=name)
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index e445c11..1c09073 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import copy
+import itertools
import os
from absl.testing import parameterized
@@ -29,12 +30,14 @@
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
@@ -48,56 +51,19 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
+from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest
-def _device_str(d):
- return "/device:GPU:" + str(d)
-
-
-def _nested_value(d):
- return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
-
-
-def _make_mirrored_val(init_val=5.0):
- v = []
- devices = ["/device:GPU:0", "/device:CPU:0"]
- for d, _ in zip(devices, ["v", "v/replica"]):
- with ops.device(d):
- v.append(constant_op.constant(init_val))
- return values_lib.Mirrored(v)
-
-
-def _make_mirrored():
- v = []
- devices = ["/device:GPU:0", "/device:CPU:0"]
- for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
- with ops.device(d):
- v.append(variable_scope.get_variable(
- name=n, initializer=init, use_resource=True))
- mirrored = values_lib.MirroredVariable(
- None, v, variable_scope.VariableAggregation.SUM)
- return mirrored
-
-
-def mirrored_and_tpu_strategy_combinations():
- return combinations.combine(
- distribution=[
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- strategy_combinations.tpu_strategy,
- strategy_combinations.tpu_strategy_packed_var,
- ],
- mode=["graph", "eager"])
-
-
class DistributedValuesTest(test.TestCase, parameterized.TestCase):
def testGetEager(self):
@@ -397,6 +363,45 @@
self.assertEqual(v.x, v_deep_copy.x)
+def _device_str(d):
+ return "/device:GPU:" + str(d)
+
+
+def _nested_value(d):
+ return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
+
+
+def _make_mirrored_val(init_val=5.0):
+ v = []
+ devices = ["/device:GPU:0", "/device:CPU:0"]
+ for d, _ in zip(devices, ["v", "v/replica"]):
+ with ops.device(d):
+ v.append(constant_op.constant(init_val))
+ return values_lib.Mirrored(v)
+
+
+def _make_mirrored():
+ v = []
+ devices = ["/device:GPU:0", "/device:CPU:0"]
+ for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
+ with ops.device(d):
+ v.append(variable_scope.get_variable(
+ name=n, initializer=init, use_resource=True))
+ mirrored = values_lib.MirroredVariable(
+ None, v, variable_scope.VariableAggregation.SUM)
+ return mirrored
+
+
+def mirrored_and_tpu_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.tpu_strategy_packed_var,
+ ],
+ mode=["graph", "eager"])
+
+
@combinations.generate(
combinations.combine(
distribution=[
@@ -791,6 +796,507 @@
save_path = self._save_normal()
self._restore_mirrored(save_path)
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_one_gpu,
+ ],
+ mode=["graph"]))
+ def testFetchAMirroredVariable(self, distribution):
+ with self.session(graph=ops.Graph()) as sess, distribution.scope():
+ with ops.device("/device:GPU:0"):
+ v = variable_scope.get_variable(
+ name="v", initializer=1., use_resource=True)
+ mirrored = values_lib.MirroredVariable(
+ distribution, (v,), variable_scope.VariableAggregation.MEAN)
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run({"complicated": mirrored})
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.tpu_strategy_packed_var,
+ ],
+ mode=["eager"]))
+ def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
+ with distribution.scope():
+ v = variables_lib.Variable(1.0, name="foo")
+
+ @def_function.function
+ def mytest():
+ def model_fn():
+ v.assign(5.0)
+ return v.read_value()
+
+ return distribution.run(model_fn)
+
+ mytest()
+ self.assertAllEqual([5.0, 5.0], self.evaluate(v.values))
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_one_cpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.tpu_strategy_packed_var,
+ ],
+ mode=["graph", "eager"]))
+ def testValueInReplicaContext(self, distribution):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ 1., aggregation=variables_lib.VariableAggregation.MEAN)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ @def_function.function
+ def f():
+ with ops.control_dependencies([v.assign_add(1.)]):
+ return v.value()
+
+ results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(f)))
+ for value in results:
+ self.assertEqual(2., value)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_one_cpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.tpu_strategy_packed_var,
+ ],
+ mode=["graph", "eager"]))
+ def testAssignOutOfScope(self, distribution):
+ with distribution.scope():
+ mirrored = variables_lib.Variable(1.)
+ self.evaluate(mirrored.assign(3.))
+ self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
+ for component in mirrored.values:
+ self.assertEqual(self.evaluate(component.read_value()), 3.)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testAssignAggregationMeanDTypeNonFloat(self, distribution):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ 1,
+ aggregation=variable_scope.VariableAggregation.MEAN,
+ dtype=dtypes.int32)
+ self.evaluate(v.initializer)
+
+ @def_function.function
+ def assign():
+ ctx = distribution_strategy_context.get_replica_context()
+ return v.assign(ctx.replica_id_in_sync_group)
+
+ # disallow assign() with distributed value in replica context.
+ with self.assertRaisesRegex(ValueError,
+ "Cannot update non-float variables"):
+ self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(assign)))
+
+ # allow assign() with same value in replica context.
+ @def_function.function
+ def assign_same():
+ return v.assign(2)
+
+ self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(assign_same)))
+ self.assertEqual(self.evaluate(v.read_value()), 2)
+
+ # allow assign() with mirrored variable in replica context.
+ with distribution.scope():
+ v2 = variables_lib.Variable(
+ 3,
+ aggregation=variable_scope.VariableAggregation.SUM,
+ dtype=dtypes.int32)
+ self.evaluate(v2.initializer)
+
+ @def_function.function
+ def assign_mirrored():
+ return v.assign(v2)
+
+ self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(assign_mirrored)))
+ self.assertEqual(self.evaluate(v.read_value()), 3)
+
+ # allow assign() in cross replica context.
+ with distribution.scope():
+ self.evaluate(v.assign(4))
+ self.assertEqual(self.evaluate(v.read_value()), 4)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.tpu_strategy_packed_var,
+ ],
+ mode=["eager"]))
+ def testInitializedToSameValueInsideEagerRun(self, distribution):
+ v = [None]
+
+ @def_function.function
+ def step():
+
+ def f():
+ if v[0] is None:
+ v[0] = variables_lib.Variable(random_ops.random_normal([]))
+
+ distribution.run(f)
+
+ context.set_global_seed(None)
+ step()
+ vals = self.evaluate(v[0].values)
+ self.assertAllEqual(vals[0], vals[1])
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_one_cpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.tpu_strategy_packed_var,
+ ],
+ mode=["graph", "eager"]))
+ def testAggregationOnlyFirstReplica(self, distribution):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 15.,
+ synchronization=variables_lib.VariableSynchronization.ON_WRITE,
+ aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ @def_function.function
+ def assign():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return v.assign(math_ops.cast(replica_id, dtypes.float32))
+ per_replica_results = self.evaluate(distribution.experimental_local_results(
+ distribution.run(assign)))
+ # The per-replica values should always match the first replica's value.
+ self.assertAllEqual(
+ array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
+ per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.tpu_strategy_packed_var,
+ ],
+ mode=["eager"]))
+ def testInitScope(self, distribution):
+
+ class C(object):
+ pass
+
+ obj = C()
+ obj.w = None
+ obj.v = None
+
+ @def_function.function
+ def assign():
+ with ops.init_scope():
+ if obj.w is None:
+ obj.w = variables_lib.Variable(
+ 0, aggregation=variables_lib.VariableAggregation.MEAN)
+ obj.v = variables_lib.Variable(
+ obj.w.read_value(),
+ aggregation=variables_lib.VariableAggregation.MEAN)
+
+ return obj.v.assign_add(2)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(distribution.run(assign)))
+ self.assertAllEqual([2, 2], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ ],
+ mode=["eager"]))
+ def testOperatorOverride(self, distribution):
+
+ with distribution.scope():
+ v = variable_scope.variable(
+ 1, aggregation=variables_lib.VariableAggregation.MEAN)
+
+ self.assertEqual(2, self.evaluate(v + 1))
+
+ @def_function.function
+ def add():
+ return v + 1
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(distribution.run(add)))
+ self.assertAllEqual([2, 2], per_replica_results)
+
+ @combinations.generate(mirrored_and_tpu_strategy_combinations())
+ def testAssignAdd(self, distribution):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 1, aggregation=variables_lib.VariableAggregation.MEAN)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ @def_function.function
+ def assign():
+ return v.assign_add(2)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(distribution.run(assign)))
+ # The per-replica values should always match the first replica's value.
+ self.assertAllEqual([3, 3], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterSub(self, distribution):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
+ self.evaluate(v.initializer)
+
+ @def_function.function
+ def scatter_sub():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.stack([
+ math_ops.cast(replica_id, dtypes.float32),
+ math_ops.cast(replica_id + 1, dtypes.float32)
+ ]),
+ indices=array_ops.stack([replica_id, replica_id + 1]),
+ dense_shape=(3,))
+ return v.scatter_sub(value)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_sub)))
+ self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterAdd(self, distribution):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
+ self.evaluate(v.initializer)
+
+ @def_function.function
+ def scatter_add():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.stack([replica_id, replica_id + 1]),
+ indices=array_ops.stack([replica_id, replica_id + 1]),
+ dense_shape=(3,))
+ return v.scatter_add(value)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_add)))
+ self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterDiv(self, distribution):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
+ self.evaluate(v.initializer)
+
+ @def_function.function
+ def scatter_div():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.reshape(replica_id + 2, [1]),
+ indices=array_ops.reshape(replica_id, [1]),
+ dense_shape=(3,))
+ return v.scatter_div(value)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_div)))
+ self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterMul(self, distribution):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
+ self.evaluate(v.initializer)
+
+ @def_function.function
+ def scatter_mul():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.reshape(
+ math_ops.cast(replica_id + 2, dtypes.float32), [1]),
+ indices=array_ops.reshape(replica_id, [1]),
+ dense_shape=(3,))
+ return v.scatter_mul(value)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_mul)))
+ self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterMin(self, distribution):
+ with distribution.scope():
+ v1 = variables_lib.Variable(
+ [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
+ v2 = variables_lib.Variable(
+ [0, 2, 0],
+ aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ @def_function.function
+ def scatter_min(v):
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.identity([1]),
+ indices=array_ops.identity([1]),
+ dense_shape=(3,))
+ return v.scatter_min(value)
+
+ with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
+ self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_min, args=(v1,))))
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_min, args=(v2,))))
+ self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterMax(self, distribution):
+ with distribution.scope():
+ v1 = variables_lib.Variable(
+ [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
+ v2 = variables_lib.Variable(
+ [0, 0, 0],
+ aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ @def_function.function
+ def scatter_max(v):
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.identity([1]),
+ indices=array_ops.identity([0]),
+ dense_shape=(3,))
+ return v.scatter_max(value)
+
+ with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
+ self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_max, args=(v1,))))
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_max, args=(v2,))))
+ self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterUpdate(self, distribution):
+ with distribution.scope():
+ v1 = variables_lib.Variable(
+ [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
+ v2 = variables_lib.Variable(
+ [0, 0, 0],
+ aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ @def_function.function
+ def scatter_update(v):
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.identity([3]),
+ indices=array_ops.identity([1]),
+ dense_shape=(3,))
+ return v.scatter_update(value)
+
+ with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
+ self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_update, args=(v1,))))
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(scatter_update, args=(v2,))))
+ self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode=["graph", "eager"]))
+ def testScatterOpsInCrossReplicaContext(self, distribution):
+ with distribution.scope():
+ v1 = variables_lib.Variable(
+ [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
+ v2 = variables_lib.Variable([1, 1, 1])
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ value = indexed_slices.IndexedSlices(
+ values=array_ops.identity([2]),
+ indices=array_ops.identity([0]),
+ dense_shape=(3,))
+ with distribution.scope():
+ self.evaluate(v1.scatter_add(value))
+ self.assertAllEqual([3, 1, 1], self.evaluate(v1.read_value()))
+
+ self.evaluate(v2.scatter_min(value))
+ self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
+
_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
@@ -815,6 +1321,38 @@
return v, replica_local
+class SyncOnReadVariablePropertiesTest(test.TestCase):
+
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testProperties(self):
+ if context.num_gpus() < 1 and context.executing_eagerly():
+ self.skipTest("A GPU is not available for this test in eager mode.")
+ v, replica_local = _make_replica_local(
+ variable_scope.VariableAggregation.SUM)
+
+ self.assertEqual(v[0].constraint, replica_local.constraint)
+ self.assertEqual(v[0].name, replica_local.name)
+ self.assertEqual(v[0].dtype, replica_local.dtype)
+ self.assertEqual(v[0].shape, replica_local.shape)
+ self.assertEqual(variable_scope.VariableAggregation.SUM,
+ replica_local.aggregation)
+
+ @test_util.run_v2_only
+ def testCanPassToDefFun(self):
+ @def_function.function
+ def add1(x):
+ return x + 1
+
+ v = variable_scope.get_variable(
+ name="v", initializer=[1.], use_resource=True)
+ replica_local = values_lib.SyncOnReadVariable(
+ None, (v,), variable_scope.VariableAggregation.MEAN)
+ self.assertEqual(2., self.evaluate(add1(replica_local)))
+
+
# TODO(b/144432582): Add variable aggregation type to combinations to simplify
# tests.
def strategy_and_run_tf_function_combinations():
@@ -851,35 +1389,6 @@
save_path, _ = self._save_return_saver(sess, var)
return save_path
- config = config_pb2.ConfigProto()
- config.allow_soft_placement = True
-
- @test_util.run_in_graph_and_eager_modes(config=config)
- def testProperties(self):
- if context.num_gpus() < 1 and context.executing_eagerly():
- self.skipTest("A GPU is not available for this test in eager mode.")
- v, replica_local = _make_replica_local(
- variable_scope.VariableAggregation.SUM)
-
- self.assertEqual(v[0].constraint, replica_local.constraint)
- self.assertEqual(v[0].name, replica_local.name)
- self.assertEqual(v[0].dtype, replica_local.dtype)
- self.assertEqual(v[0].shape, replica_local.shape)
- self.assertEqual(variable_scope.VariableAggregation.SUM,
- replica_local.aggregation)
-
- @test_util.run_v2_only
- def testCanPassToDefFun(self):
- @def_function.function
- def add1(x):
- return x + 1
-
- v = variable_scope.get_variable(
- name="v", initializer=[1.], use_resource=True)
- replica_local = values_lib.SyncOnReadVariable(
- None, (v,), variable_scope.VariableAggregation.MEAN)
- self.assertEqual(2., self.evaluate(add1(replica_local)))
-
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testTensorConversion(self, distribution):
with context.graph_mode():
@@ -1076,6 +1585,453 @@
save_path = self._save_normal()
self._restore_replica_local_sum(save_path, distribution)
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def testAssign(self, distribution, experimental_run_tf_function):
+
+ def assign(fn, v, update_value, cross_replica):
+ update_fn = lambda: getattr(v, fn)(update_value)
+ if cross_replica:
+ return update_fn()
+ else:
+ if experimental_run_tf_function:
+ update_fn = def_function.function(update_fn)
+ return distribution.experimental_local_results(
+ distribution.run(update_fn))
+
+ updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
+ aggregations = [
+ variables_lib.VariableAggregation.NONE,
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.MEAN,
+ variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+ ]
+ options = list(
+ x for x in itertools.product(updates, aggregations, [True, False]))
+ for update, aggregation, cross_replica in options:
+ # VariableAggregation.SUM in cross-replica mode is tested below,
+ # VariableAggregation.NONE in cross-replica mode is not supported.
+ if cross_replica and aggregation in [
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.NONE,
+ ]:
+ continue
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(variables_lib.global_variables_initializer())
+ fn, update_value = update
+ self.evaluate(assign(fn, v, update_value, cross_replica))
+ for component in v._values:
+ self.assertAllEqual(self.evaluate(component.read_value()),
+ self.evaluate(array_ops.ones_like(component)))
+
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def testAssignDtypeConversion(self, distribution,
+ experimental_run_tf_function):
+
+ def assign(fn, v, update_value, cross_replica):
+ update_fn = lambda: getattr(v, fn)(update_value)
+ if cross_replica:
+ return update_fn()
+ else:
+ if experimental_run_tf_function:
+ update_fn = def_function.function(update_fn)
+ return distribution.experimental_local_results(
+ distribution.run(update_fn))
+
+ updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
+ aggregations = [
+ variables_lib.VariableAggregation.NONE,
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.MEAN,
+ variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+ ]
+ options = list(
+ x for x in itertools.product(updates, aggregations, [True, False]))
+ for update, aggregation, cross_replica in options:
+ # VariableAggregation.SUM in cross-replica mode is tested below,
+ # VariableAggregation.NONE in cross-replica mode is not supported.
+ if cross_replica and aggregation in [
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.NONE,
+ ]:
+ continue
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(variables_lib.global_variables_initializer())
+ fn, update_value = update
+ self.evaluate(assign(fn, v, update_value, cross_replica))
+ for component in v._values:
+ self.assertAllEqual(self.evaluate(component.read_value()),
+ self.evaluate(array_ops.ones_like(component)))
+
+ @combinations.generate(mirrored_and_tpu_strategy_combinations())
+ def testAssignWithAggregationSum(self, distribution):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=variables_lib.VariableAggregation.SUM)
+ self.evaluate(variables_lib.global_variables_initializer())
+ self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
+ for component in v._values:
+ self.assertAllEqual(self.evaluate(component.read_value()),
+ self.evaluate(array_ops.ones_like(component)))
+
+ @combinations.generate(mirrored_and_tpu_strategy_combinations())
+ def testAssignAddSubWithAggregationSum(self, distribution):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=variables_lib.VariableAggregation.SUM)
+ self.evaluate(variables_lib.global_variables_initializer())
+ with self.assertRaisesRegex(
+ ValueError, "SyncOnReadVariable does not support "):
+ self.evaluate(v.assign_add(1.))
+ with self.assertRaisesRegex(
+ ValueError, "SyncOnReadVariable does not support "):
+ self.evaluate(v.assign_sub(1.))
+
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def testReadValueInReplicaContext(self, distribution,
+ experimental_run_tf_function):
+ aggregations = [
+ variables_lib.VariableAggregation.NONE,
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.MEAN,
+ variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+ ]
+ for aggregation in aggregations:
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(variables_lib.global_variables_initializer())
+ if experimental_run_tf_function:
+ read_var_fn = def_function.function(v.read_value)
+ else:
+ read_var_fn = v.read_value
+ results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.run(read_var_fn)))
+ for component, value in zip(v._values, results):
+ self.assertAllEqual(self.evaluate(component.read_value()), value)
+
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def testReadValueInCrossReplicaContext(self, distribution,
+ experimental_run_tf_function):
+ aggregations = [
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.MEAN,
+ variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+ ]
+ for aggregation in aggregations:
+ if isinstance(distribution, _TPU_STRATEGIES):
+ resolver = tpu_cluster_resolver.TPUClusterResolver("")
+ tpu_strategy_util.initialize_tpu_system(resolver)
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ def assign(v=v):
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return v.assign(math_ops.cast(replica_id, dtypes.float32))
+
+ if experimental_run_tf_function:
+ assign = def_function.function(assign)
+
+ self.evaluate(
+ distribution.experimental_local_results(distribution.run(assign)))
+ num_replicas = distribution.num_replicas_in_sync
+ sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
+ if aggregation == variables_lib.VariableAggregation.SUM:
+ expected = sum_of_replica_values
+ elif aggregation == variables_lib.VariableAggregation.MEAN:
+ expected = sum_of_replica_values / num_replicas
+ else:
+ expected = 0
+ self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
+ self.assertEqual(expected, self.evaluate(v.value()), aggregation)
+ self.assertEqual(expected, self.evaluate(v), aggregation)
+ self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
+ aggregation)
+
+ # TODO(b/145574622): Re-enable this test once ReduceOp argument is
+ # respected on GPUs.
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def disable_testAllReduce(self, distribution,
+ experimental_run_tf_function):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 2.,
+ synchronization=variables_lib.VariableSynchronization.ON_WRITE,
+ aggregation=variables_lib.VariableAggregation.MEAN)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ def all_reduce():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
+ dtypes.float32)
+
+ if experimental_run_tf_function:
+ all_reduce = def_function.function(all_reduce)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(distribution.run(all_reduce)))
+ expected_result = []
+ for i in range(distribution.num_replicas_in_sync):
+ expected_result.append(2.0 * distribution.num_replicas_in_sync +
+ 1.0 * i)
+ self.assertEqual(per_replica_results, tuple(expected_result))
+
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def testAssignPerReplicaBeforeRead(self, distribution,
+ experimental_run_tf_function):
+ aggregations = [
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.MEAN,
+ variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+ ]
+ for aggregation in aggregations:
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ def assign(var=v):
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return var.assign(math_ops.cast(replica_id, dtypes.float32))
+
+ if experimental_run_tf_function:
+ assign = def_function.function(assign)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(distribution.run(assign)))
+ expected_result = []
+ for i in range(distribution.num_replicas_in_sync):
+ expected_result.append(1.0 * i)
+ self.assertEqual(per_replica_results, tuple(expected_result))
+
+ @combinations.generate(mirrored_and_tpu_strategy_combinations())
+ def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=variables_lib.VariableAggregation.NONE)
+ self.evaluate(variables_lib.global_variables_initializer())
+ with self.assertRaisesRegex(
+ ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
+ self.evaluate(v.read_value())
+
+ @combinations.generate(mirrored_and_tpu_strategy_combinations())
+ def testInitializedToSameValueInsideEagerRun(self, distribution):
+ if not context.executing_eagerly(): self.skipTest("eager only")
+
+ v = [None]
+ @def_function.function
+ def step():
+ def f():
+ if v[0] is None:
+ v[0] = variables_lib.Variable(
+ random_ops.random_normal([]),
+ synchronization=variables_lib.VariableSynchronization.ON_READ)
+
+ distribution.run(f)
+
+ context.set_global_seed(None)
+ step()
+ vals = self.evaluate(v[0].values)
+ self.assertAllEqual(vals[0], vals[1])
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.tpu_strategy,
+ ],
+ mode=["eager"]))
+ def testOperatorOverride(self, distribution):
+
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.0,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=variables_lib.VariableAggregation.MEAN)
+
+ @def_function.function
+ def assign():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return v.assign(math_ops.cast(replica_id, dtypes.float32))
+
+ # Assign different replicas with different values.
+ distribution.run(assign)
+
+ self.assertEqual(1.5, self.evaluate(v + 1))
+
+ @def_function.function
+ def add():
+ return v + 1
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(distribution.run(add)))
+ self.assertAllEqual([1, 2], per_replica_results)
+
+
+@combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ aggregation=[
+ variables_lib.VariableAggregation.MEAN,
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+ ],
+ mode=["graph", "eager"]))
+class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
+
+ def testScatterSub(self, distribution, aggregation):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [1., 1., 1.],
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(v.initializer)
+
+ delta = values_lib.PerReplica([
+ indexed_slices.IndexedSlices(
+ values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
+ indexed_slices.IndexedSlices(
+ values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
+ ])
+
+ with self.assertRaises(NotImplementedError):
+ self.evaluate(distribution.run(v.scatter_sub, args=(delta,)))
+
+ def testScatterAdd(self, distribution, aggregation):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [1., 1., 1.],
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(v.initializer)
+
+ delta = values_lib.PerReplica([
+ indexed_slices.IndexedSlices(
+ values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
+ indexed_slices.IndexedSlices(
+ values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
+ ])
+
+ with self.assertRaises(NotImplementedError):
+ self.evaluate(distribution.run(v.scatter_add, args=(delta,)))
+
+ def testScatterDiv(self, distribution, aggregation):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [2., 6., 1.],
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(v.initializer)
+
+ delta = values_lib.PerReplica([
+ indexed_slices.IndexedSlices(
+ values=[[2.], [2.]], indices=[0, 1], dense_shape=(3,)),
+ indexed_slices.IndexedSlices(
+ values=[[3.], [3.]], indices=[1, 2], dense_shape=(3,)),
+ ])
+
+ with self.assertRaises(NotImplementedError):
+ self.evaluate(distribution.run(v.scatter_div, args=(delta,)))
+
+ def testScatterMul(self, distribution, aggregation):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [2., 1., 1.],
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(v.initializer)
+
+ delta = values_lib.PerReplica([
+ indexed_slices.IndexedSlices(
+ values=[[2.], [3.]], indices=[0, 1], dense_shape=(3,)),
+ indexed_slices.IndexedSlices(
+ values=[[4.], [5.]], indices=[1, 2], dense_shape=(3,)),
+ ])
+
+ with self.assertRaises(NotImplementedError):
+ self.evaluate(distribution.run(v.scatter_mul, args=(delta,)))
+
+ def testScatterMin(self, distribution, aggregation):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [3., 4., 5.],
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(v.initializer)
+
+ delta = values_lib.PerReplica([
+ indexed_slices.IndexedSlices(
+ values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
+ indexed_slices.IndexedSlices(
+ values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
+ ])
+
+ with self.assertRaises(NotImplementedError):
+ self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
+
+ def testScatterMax(self, distribution, aggregation):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [3., 4., 5.],
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(v.initializer)
+
+ delta = values_lib.PerReplica([
+ indexed_slices.IndexedSlices(
+ values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
+ indexed_slices.IndexedSlices(
+ values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
+ ])
+
+ with self.assertRaises(NotImplementedError):
+ self.evaluate(distribution.run(v.scatter_max, args=(delta,)))
+
+ def testScatterUpdate(self, distribution, aggregation):
+ with distribution.scope():
+ v = variables_lib.Variable(
+ [0., 0., 0.],
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(v.initializer)
+
+ delta = values_lib.PerReplica([
+ indexed_slices.IndexedSlices(
+ values=[[1.], [2.]], indices=[0, 1], dense_shape=(3,)),
+ indexed_slices.IndexedSlices(
+ values=[[3.], [4.]], indices=[1, 2], dense_shape=(3,)),
+ ])
+
+ with self.assertRaises(NotImplementedError):
+ self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
+
class MirroredTest(test.TestCase):
diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py
deleted file mode 100644
index 5866c0c..0000000
--- a/tensorflow/python/distribute/vars_test.py
+++ /dev/null
@@ -1,1270 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the distributed values library."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-from absl.testing import parameterized
-
-from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.distribute import strategy_combinations
-from tensorflow.python.distribute import tpu_strategy
-from tensorflow.python.distribute import tpu_values
-from tensorflow.python.distribute import values
-from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
-from tensorflow.python.eager import context
-from tensorflow.python.eager import def_function
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import indexed_slices
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.tpu import tpu_strategy_util
-
-
-_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
-
-
-def strategy_and_run_tf_function_combinations():
- # Test the combination of different strategies and whether a tf.function
- # is passed into strategy.run."""
- return combinations.combine(
- distribution=[
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- ],
- mode=["graph", "eager"],
- experimental_run_tf_function=[True, False],
- use_var_policy=[True, False]) + combinations.combine(
- distribution=[
- strategy_combinations.tpu_strategy,
- strategy_combinations.tpu_strategy_packed_var,
- ],
- mode=["graph", "eager"],
- experimental_run_tf_function=[True],
- use_var_policy=[True, False])
-
-
-def strategy_with_var_policy():
- return combinations.combine(
- distribution=[
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- strategy_combinations.tpu_strategy,
- strategy_combinations.tpu_strategy_packed_var,
- strategy_combinations.central_storage_strategy_with_two_gpus,
- ],
- mode=["graph", "eager"],
- use_var_policy=[True, False])
-
-
-class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
-
- @combinations.generate(
- combinations.combine(
- distribution=[
- strategy_combinations.mirrored_strategy_with_one_gpu,
- ],
- mode=["graph"]))
- def testFetchAMirroredVariable(self, distribution):
- with self.session(graph=ops.Graph()) as sess, distribution.scope():
- with ops.device("/device:GPU:0"):
- v = variable_scope.get_variable(
- name="v", initializer=1., use_resource=True)
- mirrored = values.MirroredVariable(
- distribution, (v,), variable_scope.VariableAggregation.MEAN)
- sess.run(variables_lib.global_variables_initializer())
- sess.run({"complicated": mirrored})
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssign(self, distribution, experimental_run_tf_function,
- use_var_policy):
-
- def assign(fn, v, update_value, cross_replica):
- update_fn = lambda: getattr(v, fn)(update_value)
- if cross_replica:
- return update_fn()
- else:
- if experimental_run_tf_function:
- update_fn = def_function.function(update_fn)
- return distribution.experimental_local_results(
- distribution.run(update_fn))
-
- updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
- aggregations = [
- variables_lib.VariableAggregation.NONE,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- options = list(
- x for x in itertools.product(updates, aggregations, [True, False]))
- for update, aggregation, cross_replica in options:
- # assign in replica context with SUM does not make sense cause you can
- # just do value * num replicas error is 1. is not a distributed value and
- # is unsupported for aggregation SUM
- if (not cross_replica and aggregation ==
- variables_lib.VariableAggregation.SUM):
- continue
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- fn, update_value = update
- self.evaluate(assign(fn, v, update_value, cross_replica))
- for component in v._values:
- self.assertAllEqual(self.evaluate(component.read_value()),
- self.evaluate(array_ops.ones_like(component)))
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssignOnWriteVar(self, distribution, experimental_run_tf_function,
- use_var_policy):
-
- with distribution.scope():
- v_to_assign = variable_scope.variable(
- 2., aggregation=variables_lib.VariableAggregation.MEAN)
- v_to_assign_sub = variable_scope.variable(
- -2., aggregation=variables_lib.VariableAggregation.MEAN)
-
- def assign(fn, v, update_value, cross_replica):
- update_fn = lambda: getattr(v, fn)(update_value)
- if cross_replica:
- return update_fn()
- else:
- if experimental_run_tf_function:
- update_fn = def_function.function(update_fn)
- return distribution.experimental_local_results(
- distribution.run(update_fn))
-
- updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
- ("assign_sub", v_to_assign_sub)]
- aggregations = [
- variables_lib.VariableAggregation.NONE,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- options = list(
- x for x in itertools.product(updates, aggregations, [True, False]))
- for update, aggregation, cross_replica in options:
- # assign in replica context with SUM does not make sense cause you can
- # just do value * num replicas error is 1. is not a distributed value and
- # is unsupported for aggregation SUM
- if aggregation == variables_lib.VariableAggregation.SUM:
- continue
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- fn, update_value = update
- self.evaluate(assign(fn, v, update_value, cross_replica))
- for component in v._values:
- self.assertAllEqual(2.0, self.evaluate(component.read_value()))
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
- use_var_policy):
-
- if isinstance(distribution, _TPU_STRATEGIES):
- self.skipTest("Assigning PerReplica values is not supported. See"
- " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")
-
- with distribution.scope():
- per_replica_value = values.PerReplica(
- [constant_op.constant(2.0),
- constant_op.constant(2.0)])
- per_replica_sub_value = values.PerReplica(
- [constant_op.constant(-2.0),
- constant_op.constant(-2.0)])
-
- def assign(fn, v, update_value, cross_replica):
- update_fn = lambda: getattr(v, fn)(update_value)
- if cross_replica:
- return update_fn()
- else:
- if experimental_run_tf_function:
- update_fn = def_function.function(update_fn)
- return distribution.experimental_local_results(
- distribution.run(update_fn))
-
- updates = [("assign", per_replica_value), ("assign_add", per_replica_value),
- ("assign_sub", per_replica_sub_value)]
- # We don't support assigning PerReplica valus to vars in replica context
- # with aggregation=NONE.
- aggregations = [
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- options = list(
- x for x in itertools.product(updates, aggregations, [True, False]))
- for update, aggregation, cross_replica in options:
- # assign in replica context with SUM does not make sense cause you can
- # just do value * num replicas error is 1. is not a distributed value and
- # is unsupported for aggregation SUM
- if cross_replica:
- # We don't support assigning PerReplica values to MirroredVariables in
- # cross replica context
- continue
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- fn, update_value = update
- self.evaluate(assign(fn, v, update_value, cross_replica))
- if aggregation == variables_lib.VariableAggregation.SUM:
- expected = 4.0
- else:
- expected = 2.0
- for component in v._values:
- self.assertAllEqual(expected, self.evaluate(component.read_value()))
-
- @combinations.generate(strategy_with_var_policy())
- def testValueInReplicaContext(self, distribution, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- 1., aggregation=variables_lib.VariableAggregation.MEAN)
- self.evaluate(variables_lib.global_variables_initializer())
-
- @def_function.function
- def f():
- with ops.control_dependencies([v.assign_add(1.)]):
- return v.value()
-
- results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(f)))
- for value in results:
- self.assertEqual(2., value)
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testReadValueInReplicaContext(self, distribution,
- experimental_run_tf_function,
- use_var_policy):
- aggregations = [
- variables_lib.VariableAggregation.NONE,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- for aggregation in aggregations:
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- if experimental_run_tf_function:
- read_var_fn = def_function.function(v.read_value)
- else:
- read_var_fn = v.read_value
- results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(read_var_fn)))
- for component, value in zip(v._values, results):
- self.assertAllEqual(self.evaluate(component.read_value()), value)
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testReadValueInCrossReplicaContext(self, distribution,
- experimental_run_tf_function,
- use_var_policy):
- aggregations = [
- variables_lib.VariableAggregation.NONE,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- for aggregation in aggregations:
- with distribution.scope():
- v = variable_scope.variable(
- 2.,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
-
- if experimental_run_tf_function:
- read_var_fn = def_function.function(v.read_value)
- else:
- read_var_fn = v.read_value
-
- results = read_var_fn()
- for component in v._values:
- self.assertEqual(self.evaluate(component.read_value()),
- self.evaluate(results))
-
- @combinations.generate(strategy_with_var_policy())
- def testAssignOutOfScope(self, distribution, use_var_policy):
- with distribution.scope():
- mirrored = variables_lib.Variable(1.)
- self.evaluate(mirrored.assign(3.))
- self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
- for component in mirrored.values:
- self.assertEqual(self.evaluate(component.read_value()), 3.)
-
- @combinations.generate(strategy_with_var_policy())
- def testAssignAggregationMeanDTypeNonFloat(self, distribution,
- use_var_policy):
- if isinstance(distribution, _TPU_STRATEGIES):
- self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
- "reenabling test.")
-
- with distribution.scope():
- v = variables_lib.Variable(
- 1,
- aggregation=variable_scope.VariableAggregation.MEAN,
- dtype=dtypes.int32)
- self.evaluate(v.initializer)
-
- @def_function.function
- def assign():
- ctx = distribution_strategy_context.get_replica_context()
- return v.assign(ctx.replica_id_in_sync_group)
-
- # disallow assign() with distributed value in replica context.
- with self.assertRaisesRegex(ValueError,
- "Cannot update non-float variables"):
- self.evaluate(
- distribution.experimental_local_results(
- distribution.run(assign)))
-
- # allow assign() with same value in replica context.
- @def_function.function
- def assign_same():
- return v.assign(2)
-
- self.evaluate(
- distribution.experimental_local_results(
- distribution.run(assign_same)))
- self.assertEqual(self.evaluate(v.read_value()), 2)
-
- # allow assign() with mirrored variable in replica context.
- with distribution.scope():
- v2 = variables_lib.Variable(
- 3,
- aggregation=variable_scope.VariableAggregation.SUM,
- dtype=dtypes.int32)
- self.evaluate(v2.initializer)
-
- @def_function.function
- def assign_mirrored():
- return v.assign(v2)
-
- self.evaluate(
- distribution.experimental_local_results(
- distribution.run(assign_mirrored)))
- self.assertEqual(self.evaluate(v.read_value()), 3)
-
- # allow assign() in cross replica context.
- with distribution.scope():
- self.evaluate(v.assign(4))
- self.assertEqual(self.evaluate(v.read_value()), 4)
-
- @combinations.generate(strategy_with_var_policy())
- def testInitializedToSameValueInsideEagerRun(self, distribution,
- use_var_policy):
- if not context.executing_eagerly(): self.skipTest("eager only test")
- v = [None]
-
- @def_function.function
- def step():
-
- def f():
- if v[0] is None:
- v[0] = variables_lib.Variable(random_ops.random_normal([]))
-
- distribution.run(f)
-
- context.set_global_seed(None)
- step()
- vals = self.evaluate(v[0].values)
- self.assertAllEqual(vals[0], vals[1])
-
- @combinations.generate(strategy_with_var_policy())
- def testAggregationOnlyFirstReplica(self, distribution, use_var_policy):
- with distribution.scope():
- v = variable_scope.variable(
- 15.,
- synchronization=variables_lib.VariableSynchronization.ON_WRITE,
- aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
- self.evaluate(variables_lib.global_variables_initializer())
-
- @def_function.function
- def assign():
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- return v.assign(math_ops.cast(replica_id, dtypes.float32))
- per_replica_results = self.evaluate(distribution.experimental_local_results(
- distribution.run(assign)))
- # The per-replica values should always match the first replicas value.
- self.assertAllEqual(
- array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
- per_replica_results)
-
- @combinations.generate(strategy_with_var_policy())
- def testInitScope(self, distribution, use_var_policy):
- if not context.executing_eagerly(): self.skipTest("eager only")
-
- class C(object):
- pass
-
- obj = C()
- obj.w = None
- obj.v = None
-
- @def_function.function
- def assign():
- with ops.init_scope():
- if obj.w is None:
- obj.w = variables_lib.Variable(
- 0, aggregation=variables_lib.VariableAggregation.MEAN)
- obj.v = variables_lib.Variable(
- obj.w.read_value(),
- aggregation=variables_lib.VariableAggregation.MEAN)
- self.evaluate(variables_lib.global_variables_initializer())
-
- return obj.v.assign_add(2)
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(distribution.run(assign)))
- self.assertAllEqual([2, 2], per_replica_results)
-
- @combinations.generate(strategy_with_var_policy())
- def testOperatorOverride(self, distribution, use_var_policy):
-
- with distribution.scope():
- v = variable_scope.variable(
- 1, aggregation=variables_lib.VariableAggregation.MEAN)
- self.evaluate(variables_lib.global_variables_initializer())
-
- self.assertEqual(2, self.evaluate(v + 1))
-
- @def_function.function
- def add():
- return v + 1
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(distribution.run(add)))
- self.assertAllEqual([2, 2], per_replica_results)
-
-
-@combinations.generate(
- combinations.combine(
- distribution=[
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- ],
- mode=["graph", "eager"],
- use_var_policy=[True, False]))
-class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
-
- def testScatterSub(self, distribution, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
- self.evaluate(v.initializer)
-
- @def_function.function
- def scatter_sub():
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- value = indexed_slices.IndexedSlices(
- values=array_ops.stack([
- math_ops.cast(replica_id, dtypes.float32),
- math_ops.cast(replica_id + 1, dtypes.float32)
- ]),
- indices=array_ops.stack([replica_id, replica_id + 1]),
- dense_shape=(3,))
- return v.scatter_sub(value)
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_sub)))
- self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)
-
- def testScatterAdd(self, distribution, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
- self.evaluate(v.initializer)
-
- @def_function.function
- def scatter_add():
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- value = indexed_slices.IndexedSlices(
- values=array_ops.stack([replica_id, replica_id + 1]),
- indices=array_ops.stack([replica_id, replica_id + 1]),
- dense_shape=(3,))
- return v.scatter_add(value)
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_add)))
- self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)
-
- def testScatterDiv(self, distribution, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
- self.evaluate(v.initializer)
-
- @def_function.function
- def scatter_div():
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- value = indexed_slices.IndexedSlices(
- values=array_ops.reshape(replica_id + 2, [1]),
- indices=array_ops.reshape(replica_id, [1]),
- dense_shape=(3,))
- return v.scatter_div(value)
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_div)))
- self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)
-
- def testScatterMul(self, distribution, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
- self.evaluate(v.initializer)
-
- @def_function.function
- def scatter_mul():
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- value = indexed_slices.IndexedSlices(
- values=array_ops.reshape(
- math_ops.cast(replica_id + 2, dtypes.float32), [1]),
- indices=array_ops.reshape(replica_id, [1]),
- dense_shape=(3,))
- return v.scatter_mul(value)
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_mul)))
- self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)
-
- def testScatterMin(self, distribution, use_var_policy):
- with distribution.scope():
- v1 = variables_lib.Variable(
- [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
- v2 = variables_lib.Variable(
- [0, 2, 0],
- aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
- self.evaluate(variables_lib.global_variables_initializer())
-
- @def_function.function
- def scatter_min(v):
- value = indexed_slices.IndexedSlices(
- values=array_ops.identity([1]),
- indices=array_ops.identity([1]),
- dense_shape=(3,))
- return v.scatter_min(value)
-
- with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
- self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_min, args=(v1,))))
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_min, args=(v2,))))
- self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)
-
- def testScatterMax(self, distribution, use_var_policy):
- with distribution.scope():
- v1 = variables_lib.Variable(
- [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
- v2 = variables_lib.Variable(
- [0, 0, 0],
- aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
- self.evaluate(variables_lib.global_variables_initializer())
-
- @def_function.function
- def scatter_max(v):
- value = indexed_slices.IndexedSlices(
- values=array_ops.identity([1]),
- indices=array_ops.identity([0]),
- dense_shape=(3,))
- return v.scatter_max(value)
-
- with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
- self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_max, args=(v1,))))
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_max, args=(v2,))))
- self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)
-
- def testScatterUpdate(self, distribution, use_var_policy):
- with distribution.scope():
- v1 = variables_lib.Variable(
- [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
- v2 = variables_lib.Variable(
- [0, 0, 0],
- aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
- self.evaluate(variables_lib.global_variables_initializer())
-
- @def_function.function
- def scatter_update(v):
- value = indexed_slices.IndexedSlices(
- values=array_ops.identity([3]),
- indices=array_ops.identity([1]),
- dense_shape=(3,))
- return v.scatter_update(value)
-
- with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
- self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_update, args=(v1,))))
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(scatter_update, args=(v2,))))
- self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)
-
- def testScatterOpsInCrossReplicaContext(self, distribution, use_var_policy):
- with distribution.scope():
- v1 = variables_lib.Variable(
- [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
- v2 = variables_lib.Variable([1, 1, 1])
- self.evaluate(variables_lib.global_variables_initializer())
-
- value = indexed_slices.IndexedSlices(
- values=array_ops.identity([2]),
- indices=array_ops.identity([0]),
- dense_shape=(3,))
- with distribution.scope():
- self.evaluate(v1.scatter_add(value))
- self.assertAllEqual([3, 1, 1], self.evaluate(v1.read_value()))
-
- self.evaluate(v2.scatter_min(value))
- self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
-
-
-def _make_replica_local(method, strategy=None):
- if strategy is None:
- devices = ("/device:GPU:0", "/device:CPU:0")
- else:
- devices = strategy.extended.worker_devices
-
- v = []
- for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
- with ops.device(d):
- v.append(variable_scope.get_variable(
- name=n, initializer=init, use_resource=True))
-
- if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
- var_cls = tpu_values.TPUSyncOnReadVariable
- else:
- var_cls = values.SyncOnReadVariable
- replica_local = var_cls(strategy, v, method)
- return v, replica_local
-
-
-class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssign(self, distribution, experimental_run_tf_function,
- use_var_policy):
-
- def assign(fn, v, update_value, cross_replica):
- update_fn = lambda: getattr(v, fn)(update_value)
- if cross_replica:
- return update_fn()
- else:
- if experimental_run_tf_function:
- update_fn = def_function.function(update_fn)
- return distribution.experimental_local_results(
- distribution.run(update_fn))
-
- updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
- aggregations = [
- variables_lib.VariableAggregation.NONE,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- options = list(
- x for x in itertools.product(updates, aggregations, [True, False]))
- for update, aggregation, cross_replica in options:
- # VariableAggregation.SUM in cross-replica mode is tested below,
- # VariableAggregation.NONE in cross-replica mode is not supported.
- if cross_replica and aggregation in [
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.NONE,
- ]:
- continue
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- fn, update_value = update
- self.evaluate(assign(fn, v, update_value, cross_replica))
- for component in v._values:
- self.assertAllEqual(self.evaluate(component.read_value()),
- self.evaluate(array_ops.ones_like(component)))
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssignOnReadVar(self, distribution, experimental_run_tf_function,
- use_var_policy):
-
- with distribution.scope():
- v_to_assign = variable_scope.variable(
- 2., aggregation=variables_lib.VariableAggregation.MEAN)
- v_to_assign_sub = variable_scope.variable(
- -2., aggregation=variables_lib.VariableAggregation.MEAN)
-
- def assign(fn, v, update_value, cross_replica):
- update_fn = lambda: getattr(v, fn)(update_value)
- if cross_replica:
- return update_fn()
- else:
- if experimental_run_tf_function:
- update_fn = def_function.function(update_fn)
- return distribution.experimental_local_results(
- distribution.run(update_fn))
-
- updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
- ("assign_sub", v_to_assign_sub)]
- expected_cross_replica = {
- variables_lib.VariableAggregation.SUM: 1.0,
- variables_lib.VariableAggregation.MEAN: 2.0,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
- }
- expected_replica = {
- variables_lib.VariableAggregation.SUM: 2.0,
- variables_lib.VariableAggregation.MEAN: 2.0,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
- }
- # aggregation=NONE is not supported for OnReadVariables.
- aggregations = [
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- options = list(
- x for x in itertools.product(updates, aggregations, [True, False]))
- for update, aggregation, cross_replica in options:
- # assign in replica context with SUM does not make sense cause you can
- # just do value * num replicas error is 1. is not a distributed value and
- # is unsupported for aggregation SUM
- if aggregation == variables_lib.VariableAggregation.SUM:
- continue
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- fn, update_value = update
- self.evaluate(assign(fn, v, update_value, cross_replica))
- if cross_replica:
- for component in v._values:
- self.assertAllEqual(expected_cross_replica.get(aggregation),
- self.evaluate(component.read_value()))
- else:
- for component in v._values:
- self.assertAllEqual(expected_replica.get(aggregation),
- self.evaluate(component.read_value()))
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
- use_var_policy):
-
- if isinstance(distribution, _TPU_STRATEGIES):
- self.skipTest("Assigning PerReplica values is not supported. See"
- " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")
-
- self.skipTest("We don't support assiging PerReplica values in cross "
- "replica context or replica context. see error in "
- "sponge/2b2e54c1-eda6-4534-82e1-c73b1dcd517f.")
-
- with distribution.scope():
- per_replica_value = values.PerReplica(
- [constant_op.constant(2.0),
- constant_op.constant(2.0)])
-
- def assign(fn, v, update_value, cross_replica):
- update_fn = lambda: getattr(v, fn)(update_value)
- if cross_replica:
- return update_fn()
- else:
- if experimental_run_tf_function:
- update_fn = def_function.function(update_fn)
- return distribution.experimental_local_results(
- distribution.run(update_fn))
-
- updates = [("assign", per_replica_value)]
- # We don't support assigning PerReplica valus to vars in replica context
- # with aggregation=NONE.
- aggregations = [
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- options = list(
- x for x in itertools.product(updates, aggregations, [True, False]))
- for update, aggregation, cross_replica in options:
- # assign in replica context with SUM does not make sense cause you can
- # just do value * num replicas error is 1. is not a distributed value and
- # is unsupported for aggregation SUM
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- fn, update_value = update
- # with self.assertRaisesRegex(ValueError, "Attempt to convert a value "):
- self.evaluate(assign(fn, v, update_value, cross_replica))
- if aggregation == variables_lib.VariableAggregation.SUM:
- expected = 4.0
- else:
- expected = 2.0
- for component in v._values:
- self.assertAllEqual(expected, self.evaluate(component.read_value()))
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssignDtypeConversion(self, distribution,
- experimental_run_tf_function,
- use_var_policy):
-
- def assign(fn, v, update_value, cross_replica):
- update_fn = lambda: getattr(v, fn)(update_value)
- if cross_replica:
- return update_fn()
- else:
- if experimental_run_tf_function:
- update_fn = def_function.function(update_fn)
- return distribution.experimental_local_results(
- distribution.run(update_fn))
-
- updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
- aggregations = [
- variables_lib.VariableAggregation.NONE,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- options = list(
- x for x in itertools.product(updates, aggregations, [True, False]))
- for update, aggregation, cross_replica in options:
- # VariableAggregation.SUM in cross-replica mode is tested below,
- # VariableAggregation.NONE in cross-replica mode is not supported.
- if cross_replica and aggregation in [
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.NONE,
- ]:
- continue
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- fn, update_value = update
- self.evaluate(assign(fn, v, update_value, cross_replica))
- for component in v._values:
- self.assertAllEqual(self.evaluate(component.read_value()),
- self.evaluate(array_ops.ones_like(component)))
-
- @combinations.generate(strategy_with_var_policy())
- def testAssignWithAggregationSum(self, distribution, use_var_policy):
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=variables_lib.VariableAggregation.SUM)
- self.evaluate(variables_lib.global_variables_initializer())
- self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
- for component in v._values:
- self.assertAllEqual(self.evaluate(component.read_value()),
- self.evaluate(array_ops.ones_like(component)))
-
- @combinations.generate(strategy_with_var_policy())
- def testAssignAddSubWithAggregationSum(self, distribution, use_var_policy):
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=variables_lib.VariableAggregation.SUM)
- self.evaluate(variables_lib.global_variables_initializer())
- with self.assertRaisesRegex(
- ValueError, "SyncOnReadVariable does not support "):
- self.evaluate(v.assign_add(1.))
- with self.assertRaisesRegex(
- ValueError, "SyncOnReadVariable does not support "):
- self.evaluate(v.assign_sub(1.))
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testReadValueInReplicaContext(self, distribution,
- experimental_run_tf_function,
- use_var_policy):
- aggregations = [
- variables_lib.VariableAggregation.NONE,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- for aggregation in aggregations:
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
- if experimental_run_tf_function:
- read_var_fn = def_function.function(v.read_value)
- else:
- read_var_fn = v.read_value
- results = self.evaluate(
- distribution.experimental_local_results(
- distribution.run(read_var_fn)))
- for component, value in zip(v._values, results):
- self.assertAllEqual(self.evaluate(component.read_value()), value)
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testReadValueInCrossReplicaContext(self, distribution,
- experimental_run_tf_function,
- use_var_policy):
- aggregations = [
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- for aggregation in aggregations:
- if isinstance(distribution, _TPU_STRATEGIES):
- resolver = tpu_cluster_resolver.TPUClusterResolver("")
- tpu_strategy_util.initialize_tpu_system(resolver)
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
-
- def assign(v=v):
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- return v.assign(math_ops.cast(replica_id, dtypes.float32))
-
- if experimental_run_tf_function:
- assign = def_function.function(assign)
-
- self.evaluate(
- distribution.experimental_local_results(distribution.run(assign)))
- num_replicas = distribution.num_replicas_in_sync
- sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
- if aggregation == variables_lib.VariableAggregation.SUM:
- expected = sum_of_replica_values
- elif aggregation == variables_lib.VariableAggregation.MEAN:
- expected = sum_of_replica_values / num_replicas
- else:
- expected = 0
- self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
- self.assertEqual(expected, self.evaluate(v.value()), aggregation)
- self.assertEqual(expected, self.evaluate(v), aggregation)
- self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
- aggregation)
-
- # TODO(b/145574622): Re-enable this test once ReduceOp argument is
- # respected on GPUs.
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def disable_testAllReduce(self, distribution,
- experimental_run_tf_function,
- use_var_policy):
- with distribution.scope():
- v = variable_scope.variable(
- 2.,
- synchronization=variables_lib.VariableSynchronization.ON_WRITE,
- aggregation=variables_lib.VariableAggregation.MEAN)
- self.evaluate(variables_lib.global_variables_initializer())
-
- def all_reduce():
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
- dtypes.float32)
-
- if experimental_run_tf_function:
- all_reduce = def_function.function(all_reduce)
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(distribution.run(all_reduce)))
- expected_result = []
- for i in range(distribution.num_replicas_in_sync):
- expected_result.append(2.0 * distribution.num_replicas_in_sync +
- 1.0 * i)
- self.assertEqual(per_replica_results, tuple(expected_result))
-
- @combinations.generate(strategy_and_run_tf_function_combinations())
- def testAssignPerReplicaBeforeRead(self, distribution,
- experimental_run_tf_function,
- use_var_policy):
- aggregations = [
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ]
- for aggregation in aggregations:
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(variables_lib.global_variables_initializer())
-
- def assign(var=v):
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- return var.assign(math_ops.cast(replica_id, dtypes.float32))
-
- if experimental_run_tf_function:
- assign = def_function.function(assign)
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(distribution.run(assign)))
- expected_result = []
- for i in range(distribution.num_replicas_in_sync):
- expected_result.append(1.0 * i)
- self.assertEqual(per_replica_results, tuple(expected_result))
-
- @combinations.generate(strategy_with_var_policy())
- def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution,
- use_var_policy):
- with distribution.scope():
- v = variable_scope.variable(
- 0.,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=variables_lib.VariableAggregation.NONE)
- self.evaluate(variables_lib.global_variables_initializer())
- with self.assertRaisesRegex(
- ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
- self.evaluate(v.read_value())
-
- @combinations.generate(strategy_with_var_policy())
- def testInitializedToSameValueInsideEagerRun(self, distribution,
- use_var_policy):
- if not context.executing_eagerly(): self.skipTest("eager only")
-
- v = [None]
- @def_function.function
- def step():
- def f():
- if v[0] is None:
- v[0] = variables_lib.Variable(
- random_ops.random_normal([]),
- synchronization=variables_lib.VariableSynchronization.ON_READ)
-
- distribution.run(f)
-
- context.set_global_seed(None)
- step()
- vals = self.evaluate(v[0].values)
- self.assertAllEqual(vals[0], vals[1])
-
- @combinations.generate(strategy_with_var_policy())
- def testOperatorOverride(self, distribution, use_var_policy):
-
- with distribution.scope():
- v = variable_scope.variable(
- 0.0,
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=variables_lib.VariableAggregation.MEAN)
- self.evaluate(variables_lib.global_variables_initializer())
-
- @def_function.function
- def assign():
- ctx = distribution_strategy_context.get_replica_context()
- replica_id = ctx.replica_id_in_sync_group
- return v.assign(math_ops.cast(replica_id, dtypes.float32))
-
- # Assign different replicas with different values.
- self.evaluate(distribution.experimental_local_results(
- distribution.run(assign)))
- self.assertEqual(1.5, self.evaluate(v + 1))
-
- @def_function.function
- def add():
- return v + 1
-
- per_replica_results = self.evaluate(
- distribution.experimental_local_results(distribution.run(add)))
- self.assertAllEqual([1, 2], per_replica_results)
-
-
-@combinations.generate(
- combinations.combine(
- distribution=[
- strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
- ],
- aggregation=[
- variables_lib.VariableAggregation.MEAN,
- variables_lib.VariableAggregation.SUM,
- variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
- ],
- mode=["graph", "eager"],
- use_var_policy=[True, False]))
-class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
-
- def testScatterSub(self, distribution, aggregation, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [1., 1., 1.],
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(v.initializer)
-
- delta = values.PerReplica([
- indexed_slices.IndexedSlices(
- values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
- indexed_slices.IndexedSlices(
- values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
- ])
-
- with self.assertRaises(NotImplementedError):
- self.evaluate(distribution.run(v.scatter_sub, args=(delta,)))
-
- def testScatterAdd(self, distribution, aggregation, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [1., 1., 1.],
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(v.initializer)
-
- delta = values.PerReplica([
- indexed_slices.IndexedSlices(
- values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
- indexed_slices.IndexedSlices(
- values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
- ])
-
- with self.assertRaises(NotImplementedError):
- self.evaluate(distribution.run(v.scatter_add, args=(delta,)))
-
- def testScatterDiv(self, distribution, aggregation, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [2., 6., 1.],
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(v.initializer)
-
- delta = values.PerReplica([
- indexed_slices.IndexedSlices(
- values=[[2.], [2.]], indices=[0, 1], dense_shape=(3,)),
- indexed_slices.IndexedSlices(
- values=[[3.], [3.]], indices=[1, 2], dense_shape=(3,)),
- ])
-
- with self.assertRaises(NotImplementedError):
- self.evaluate(distribution.run(v.scatter_div, args=(delta,)))
-
- def testScatterMul(self, distribution, aggregation, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [2., 1., 1.],
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(v.initializer)
-
- delta = values.PerReplica([
- indexed_slices.IndexedSlices(
- values=[[2.], [3.]], indices=[0, 1], dense_shape=(3,)),
- indexed_slices.IndexedSlices(
- values=[[4.], [5.]], indices=[1, 2], dense_shape=(3,)),
- ])
-
- with self.assertRaises(NotImplementedError):
- self.evaluate(distribution.run(v.scatter_mul, args=(delta,)))
-
- def testScatterMin(self, distribution, aggregation, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [3., 4., 5.],
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(v.initializer)
-
- delta = values.PerReplica([
- indexed_slices.IndexedSlices(
- values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
- indexed_slices.IndexedSlices(
- values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
- ])
-
- with self.assertRaises(NotImplementedError):
- self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
-
- def testScatterMax(self, distribution, aggregation, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [3., 4., 5.],
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(v.initializer)
-
- delta = values.PerReplica([
- indexed_slices.IndexedSlices(
- values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
- indexed_slices.IndexedSlices(
- values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
- ])
-
- with self.assertRaises(NotImplementedError):
- self.evaluate(distribution.run(v.scatter_max, args=(delta,)))
-
- def testScatterUpdate(self, distribution, aggregation, use_var_policy):
- with distribution.scope():
- v = variables_lib.Variable(
- [0., 0., 0.],
- synchronization=variables_lib.VariableSynchronization.ON_READ,
- aggregation=aggregation)
- self.evaluate(v.initializer)
-
- delta = values.PerReplica([
- indexed_slices.IndexedSlices(
- values=[[1.], [2.]], indices=[0, 1], dense_shape=(3,)),
- indexed_slices.IndexedSlices(
- values=[[3.], [4.]], indices=[1, 2], dense_shape=(3,)),
- ])
-
- with self.assertRaises(NotImplementedError):
- self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
-
-
-def _make_index_slices(vals, indices, dense_shape=None):
- if dense_shape:
- dense_shape = array_ops.identity(dense_shape)
- return indexed_slices.IndexedSlices(
- array_ops.identity(vals), array_ops.identity(indices), dense_shape)
-
-
-if __name__ == "__main__":
- test.main()