Add support for variable policy to be used by MirroredStrategy and TPUStrategy. Refactor existing values test and add an option to test variable policy.

PiperOrigin-RevId: 323304491
Change-Id: I5c00791bc62a930274c254b33f4a47d671d0b7bf
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 185b456..356fb3a 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -302,7 +302,6 @@
         ":distribute_lib",
         ":reduce_util",
         ":shared_variable_creator",
-        ":tpu_values",
         ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:config",
@@ -1226,46 +1225,6 @@
 )
 
 distribute_py_test(
-    name = "vars_test",
-    size = "medium",
-    srcs = ["vars_test.py"],
-    main = "vars_test.py",
-    shard_count = 5,
-    tags = [
-        "multi_and_single_gpu",
-        "no_rocm",
-    ],
-    tpu_tags = [
-        "no_oss",  # b/150954621 Target too big to run serially reliably.
-    ],
-    deps = [
-        ":combinations",
-        ":distribute_lib",
-        ":strategy_combinations",
-        ":tpu_strategy",
-        ":tpu_values",
-        ":values",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:checkpoint_management",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:indexed_slices",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:random_ops",
-        "//tensorflow/python:training",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python:variables",
-        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
-        "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/eager:test",
-        "//tensorflow/python/tpu:tpu_lib",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
-distribute_py_test(
     name = "ps_values_test",
     size = "medium",
     srcs = ["ps_values_test.py"],
diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py
index a86c751..ad8bb87 100644
--- a/tensorflow/python/distribute/combinations.py
+++ b/tensorflow/python/distribute/combinations.py
@@ -58,17 +58,11 @@
   """
 
   def modified_arguments(self, kwargs, requested_parameters):
-    # Get the parameter that indicates if we need to set the `_use_policy` flag
-    # on the strategy object. This is a temporary flag for testing the variable
-    # policy rollout.
-    use_var_policy = kwargs.get("use_var_policy", None)
+    del requested_parameters
     distribution_arguments = {}
     for k, v in kwargs.items():
       if isinstance(v, NamedDistribution):
-        strategy = v.strategy
-        if use_var_policy:
-          strategy.extended._use_var_policy = use_var_policy
-        distribution_arguments[k] = strategy
+        distribution_arguments[k] = v.strategy
     return distribution_arguments
 
 
diff --git a/tensorflow/python/distribute/distribute_utils.py b/tensorflow/python/distribute/distribute_utils.py
index 916ebaf..89848b9 100644
--- a/tensorflow/python/distribute/distribute_utils.py
+++ b/tensorflow/python/distribute/distribute_utils.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.distribute import tpu_values as tpu_values_lib
 from tensorflow.python.distribute import values as values_lib
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
@@ -146,7 +145,7 @@
 
   def _get_mirrored(x):
     if isinstance(x, values_lib.DistributedValues):
-      if not is_mirrored(x):
+      if not isinstance(x, values_lib.Mirrored):
         raise TypeError(
             "Expected value to be mirrored across replicas: %s in %s." %
             (x, structured))
@@ -246,25 +245,34 @@
 
 
 # Variable creation function for sync strategies.
-def _get_and_validate_synchronization(kwargs):
-  """Validate that given synchronization value is valid."""
+def create_mirrored_variable(  # pylint: disable=missing-docstring
+    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
+  # Figure out what collections this variable should be added to.
+  # We'll add the MirroredVariable to those collections instead.
+  var_collections = kwargs.pop("collections", None)
+  if var_collections is None:
+    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+  kwargs["collections"] = []
+
   synchronization = kwargs.get("synchronization",
-                               vs.VariableSynchronization.AUTO)
+                               vs.VariableSynchronization.ON_WRITE)
+
   if synchronization == vs.VariableSynchronization.NONE:
     raise ValueError(
-        "`NONE` variable synchronization mode is not supported with "
-        "tf.distribute strategy. Please change the `synchronization` for "
+        "`NONE` variable synchronization mode is not supported with `Mirrored` "
+        "distribution strategy. Please change the `synchronization` for "
         "variable: " + str(kwargs["name"]))
-  if synchronization not in (vs.VariableSynchronization.ON_READ,
-                             vs.VariableSynchronization.ON_WRITE,
-                             vs.VariableSynchronization.AUTO):
+  elif synchronization == vs.VariableSynchronization.ON_READ:
+    is_sync_on_read = True
+  elif synchronization in (vs.VariableSynchronization.ON_WRITE,
+                           vs.VariableSynchronization.AUTO):
+    # `AUTO` synchronization defaults to `ON_WRITE`.
+    is_sync_on_read = False
+  else:
     raise ValueError(
         "Invalid variable synchronization mode: %s for variable: %s" %
         (synchronization, kwargs["name"]))
-  return synchronization
 
-
-def _validate_aggregation(kwargs):
   aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
 
   if aggregation not in (vs.VariableAggregation.NONE,
@@ -273,33 +281,6 @@
                          vs.VariableAggregation.ONLY_FIRST_REPLICA):
     raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                      (aggregation, kwargs["name"]))
-  return aggregation
-
-
-def _get_variable_policy_class(synchronization, aggregation, policy_mapping):
-  if synchronization == vs.VariableSynchronization.AUTO:
-    if aggregation == vs.VariableAggregation.NONE:
-      # Use AutoPolicy.
-      return policy_mapping.get(synchronization)
-    else:
-      # Revert to OnWritePolicy
-      return policy_mapping.get(vs.VariableSynchronization.ON_WRITE)
-  return policy_mapping.get(synchronization)
-
-
-def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
-                             policy_mapping, **kwargs):
-  """Create distributed variables with given synchronization and aggregation."""
-  # Figure out what collections this variable should be added to.
-  # We'll add the MirroredVariable to those collections instead.
-  var_collections = kwargs.pop("collections", None)
-  if var_collections is None:
-    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
-  kwargs["collections"] = []
-
-  synchronization = _get_and_validate_synchronization(kwargs)
-  aggregation = _validate_aggregation(kwargs)
-  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)
 
   # Ignore user-specified caching device, not needed for mirrored variables.
   kwargs.pop("caching_device", None)
@@ -309,15 +290,8 @@
   # here.
   with tape.stop_recording():
     value_list = real_mirrored_creator(**kwargs)
-    if use_var_policy:
-      var_policy_cls = _get_variable_policy_class(synchronization, aggregation,
-                                                  policy_mapping)
-      var_policy = var_policy_cls(aggregation=aggregation)
-      var_cls = class_mapping.get("VariableClass")
-      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
-    else:
-      var_cls = class_mapping.get(synchronization)
-      result = var_cls(strategy, value_list, aggregation)
+    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
+    result = var_cls(strategy, value_list, aggregation)
     # Install the created DistributedVariable as _distributed_container property
     # of the underlying variables, to make it easy to map back to the container.
     for v in result.values:
@@ -350,55 +324,3 @@
     ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
 
   return result
-
-
-# Utility functions
-# Return True if the Value is Mirrored or the Variable is replicated and kept in
-# sync.
-def is_mirrored(val):
-  if isinstance(val, values_lib.DistributedVariable):
-    if val._policy:  # pylint: disable=protected-access
-      return val._policy._is_mirrored()  # pylint: disable=protected-access
-  return isinstance(val, values_lib.Mirrored)
-
-
-def is_sync_on_read(val):
-  if isinstance(val, values_lib.DistributedVariable):
-    if val._policy:  # pylint: disable=protected-access
-      return not val._policy._is_mirrored()  # pylint: disable=protected-access
-  return not isinstance(val, values_lib.Mirrored)
-
-# The following mapping indicates the policy that you must use for a given
-# variable `synchronization` and `aggregation` pair.
-# AutoPolicy is used for:
-# (synchronization=Auto, aggregation=None)
-# OnWritePolicy is used for:
-# (synchronization=Auto, aggregation=SUM,MEAN,ONLY_FIRST_REPLICA)
-# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
-# OnReadPolicy is used for:
-# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
-VARIABLE_POLICY_MAPPING = {
-    vs.VariableSynchronization.AUTO: values_lib.AutoPolicy,
-    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
-    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
-}
-
-VARIABLE_CLASS_MAPPING = {
-    "VariableClass": values_lib.DistributedVariable,
-    vs.VariableSynchronization.AUTO: values_lib.MirroredVariable,
-    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
-    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
-}
-
-TPU_VARIABLE_POLICY_MAPPING = {
-    vs.VariableSynchronization.AUTO: tpu_values_lib.TPUAutoPolicy,
-    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
-    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
-}
-
-TPU_VARIABLE_CLASS_MAPPING = {
-    "VariableClass": tpu_values_lib.TPUDistributedVariable,
-    vs.VariableSynchronization.AUTO: tpu_values_lib.TPUMirroredVariable,
-    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
-    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
-}
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 5323f61..b424f79 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -319,9 +319,6 @@
     if ops.executing_eagerly_outside_functions():
       self.experimental_enable_get_next_as_optional = True
 
-    # Flag to turn on VariablePolicy.
-    self._use_var_policy = False
-
   def _initialize_strategy(self, devices):
     # The _initialize_strategy method is intended to be used by distribute
     # coordinator as well.
@@ -465,8 +462,7 @@
 
     return distribute_utils.create_mirrored_variable(
         self._container_strategy(), _real_mirrored_creator,
-        distribute_utils.VARIABLE_CLASS_MAPPING,
-        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)
+        values.MirroredVariable, values.SyncOnReadVariable, **kwargs)
 
   def _validate_colocate_with_variable(self, colocate_with_variable):
     distribute_utils.validate_colocate_distributed_variable(
@@ -632,10 +628,10 @@
     return self._cross_device_ops or self._inferred_cross_device_ops
 
   def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
-    if (distribute_utils.is_mirrored(value) and
+    if (isinstance(value, values.Mirrored) and
         reduce_op == reduce_util.ReduceOp.MEAN):
       return value
-    assert not distribute_utils.is_mirrored(value)
+    assert not isinstance(value, values.Mirrored)
     if not isinstance(value, values.DistributedValues):
       # This function handles reducing values that are not PerReplica or
       # Mirrored values. For example, the same value could be present on all
@@ -690,12 +686,10 @@
 
   def read_var(self, replica_local_var):
     """Read the aggregate value of a replica-local variable."""
-    # pylint: disable=protected-access
-    if values._is_sync_on_read(replica_local_var):
-      return replica_local_var._get_cross_replica()
-    assert values._is_mirrored(replica_local_var)
-    return array_ops.identity(replica_local_var._get())
-    # pylint: enable=protected-access
+    if isinstance(replica_local_var, values.SyncOnReadVariable):
+      return replica_local_var._get_cross_replica()  # pylint: disable=protected-access
+    assert isinstance(replica_local_var, values.Mirrored)
+    return array_ops.identity(replica_local_var._get())  # pylint: disable=protected-access
 
   def _local_results(self, val):
     if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py
index 03d697f..8e7d674 100644
--- a/tensorflow/python/distribute/mirrored_variable_test.py
+++ b/tensorflow/python/distribute/mirrored_variable_test.py
@@ -379,7 +379,8 @@
     with distribution.scope():
       with self.assertRaisesRegex(
           ValueError, "`NONE` variable synchronization mode is not "
-          "supported with "):
+          "supported with `Mirrored` distribution strategy. Please change "
+          "the `synchronization` for variable: v"):
         variable_scope.get_variable(
             "v", [1],
             synchronization=variable_scope.VariableSynchronization.NONE)
@@ -388,7 +389,8 @@
     with distribution.scope():
       with self.assertRaisesRegex(
           ValueError, "`NONE` variable synchronization mode is not "
-          "supported with "):
+          "supported with `Mirrored` distribution strategy. Please change "
+          "the `synchronization` for variable: v"):
         variable_scope.variable(
             1.0,
             name="v",
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index bad6e6a..8e5ef06 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -544,9 +544,6 @@
           context.async_wait()
       atexit.register(async_wait)
 
-    # Flag to turn on VariablePolicy
-    self._use_var_policy = False
-
   def _validate_colocate_with_variable(self, colocate_with_variable):
     distribute_utils. validate_colocate(colocate_with_variable, self)
 
@@ -873,8 +870,8 @@
 
     return distribute_utils.create_mirrored_variable(
         self._container_strategy(), _real_mirrored_creator,
-        distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
-        distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)
+        tpu_values.TPUMirroredVariable, tpu_values.TPUSyncOnReadVariable,
+        **kwargs)
 
   def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     if (isinstance(value, values.DistributedValues) or
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index ce6d2e7..3388553 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -197,58 +197,10 @@
   return None
 
 
-class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
-  """DistributedVariable subclass for TPUStrategy."""
-
-  def _is_mirrored(self):
-    self._policy._is_mirrored()  # pylint: disable=protected-access
-
-  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
-    return self._policy.assign_sub(
-        self, value, use_locking=use_locking, name=name, read_value=read_value)
-
-  def assign_add(self, value, use_locking=False, name=None, read_value=True):
-    return self._policy.assign_add(
-        self, value, use_locking=use_locking, name=name, read_value=read_value)
-
-  def assign(self, value, use_locking=False, name=None, read_value=True):
-    return self._policy.assign(
-        self, value, use_locking=use_locking, name=name, read_value=read_value)
-
-  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
-    return self._policy.scatter_sub(
-        self, sparse_delta, use_locking=use_locking, name=name)
-
-  def scatter_add(self, sparse_delta, use_locking=False, name=None):
-    return self._policy.scatter_add(
-        self, sparse_delta, use_locking=use_locking, name=name)
-
-  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
-    return self._policy.scatter_mul(
-        self, sparse_delta, use_locking=use_locking, name=name)
-
-  def scatter_div(self, sparse_delta, use_locking=False, name=None):
-    return self._policy.scatter_div(
-        self, sparse_delta, use_locking=use_locking, name=name)
-
-  def scatter_min(self, sparse_delta, use_locking=False, name=None):
-    return self._policy.scatter_min(
-        self, sparse_delta, use_locking=use_locking, name=name)
-
-  def scatter_max(self, sparse_delta, use_locking=False, name=None):
-    return self._policy.scatter_max(
-        self, sparse_delta, use_locking=use_locking, name=name)
-
-  def scatter_update(self, sparse_delta, use_locking=False, name=None):
-    return self._policy.scatter_update(
-        self, sparse_delta, use_locking=use_locking, name=name)
-
-
 class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   """Holds a map from replica to TPU variables whose values are kept in sync."""
 
-  def assign_sub(self, value, use_locking=False, name=None,
-                 read_value=True):
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
     if (enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
@@ -258,11 +210,17 @@
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
 
-  def assign_add(self, value, use_locking=False, name=None,
-                 read_value=True):
+    assign_sub_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_sub_variable_op)
+    return self._update(
+        update_fn=assign_sub_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
+
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
     if (enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
@@ -272,21 +230,34 @@
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+
+    assign_add_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_add_variable_op)
+    return self._update(
+        update_fn=assign_add_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
 
   def assign(self, value, use_locking=False, name=None, read_value=True):
     if (enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              self,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(self, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          self,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+
+    assign_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_variable_op)
+    return self._update(
+        update_fn=assign_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -341,220 +312,3 @@
 
   def _is_mirrored(self):
     return False
-
-
-# Common method between AutoPolicy, OnWrite and Mirrored variables.
-def assign_sub(var, value, use_locking=False, name=None, read_value=True):
-  assign_sub_fn = _make_raw_assign_fn(
-      gen_resource_variable_ops.assign_sub_variable_op)
-  return var._update(  # pylint: disable=protected-access
-      update_fn=assign_sub_fn,
-      value=value,
-      use_locking=use_locking,
-      name=name,
-      read_value=read_value)
-
-
-def assign_add(var, value, use_locking=False, name=None, read_value=True):
-  assign_add_fn = _make_raw_assign_fn(
-      gen_resource_variable_ops.assign_add_variable_op)
-  return var._update(  # pylint: disable=protected-access
-      update_fn=assign_add_fn,
-      value=value,
-      use_locking=use_locking,
-      name=name,
-      read_value=read_value)
-
-
-def assign(var, value, use_locking=False, name=None, read_value=True):
-  assign_fn = _make_raw_assign_fn(
-      gen_resource_variable_ops.assign_variable_op)
-  return var._update(  # pylint: disable=protected-access
-      update_fn=assign_fn,
-      value=value,
-      use_locking=use_locking,
-      name=name,
-      read_value=read_value)
-
-
-class TPUAutoPolicy(values.AutoPolicy):
-  """Policy defined for `tf.VariableSynchronization.AUTO` synchronization.
-
-  This policy is created when `synchronization` is set to
-  `tf.VariableSynchronization.AUTO` and `aggregation` is set to
-  `tf.VariableAggregation.NONE` when creating a `tf.Variable` in `tf.distribute`
-  scope.
-  """
-
-  def assign_sub(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    if enclosing_tpu_context():
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_sub_variable_op)(
-              var,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
-
-  def assign_add(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    if enclosing_tpu_context():
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_add_variable_op)(
-              var,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
-
-  def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    if enclosing_tpu_context():
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              var,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
-
-  def scatter_sub(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_add(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_max(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_min(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_mul(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_div(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_update(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def _is_mirrored(self):
-    return True
-
-
-class TPUOnWritePolicy(values.OnWritePolicy):
-  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
-
-  This policy is created when the following `synchronization` and
-  `aggregation` parameters are specified when creating a `tf.Variable` in
-  `tf.distribute` scope:
-  * `synchronization` is equal to `tf.VariableSynchronization.AUTO` and
-  aggregation can be any of the following `tf.VariableAggregation` enum
-  values such as `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
-  * `synchronization` is equal to `tf.VariableSynchronization.ON_WRITE` and
-  aggregation can be any of the following `tf.VariableAggregation` enum
-  values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
-  """
-
-  def assign_sub(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
-
-  def assign_add(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
-
-  def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
-
-  def scatter_sub(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_add(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_max(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_min(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_mul(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_div(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_update(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def _is_mirrored(self):
-    return True
-
-
-class TPUOnReadPolicy(values.OnReadPolicy):
-  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
-
-  This policy is created when `synchronization` is set to
-  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
-  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
-  `MEAN` or `ONLY_FIRST_REPLICA`when creating a `tf.Variable` in `tf.distribute`
-  scope.
-  """
-
-  def assign_sub(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
-      return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
-                                                            **kwargs)
-
-  def assign_add(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
-      return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_add_variable_op)(var, *args,
-                                                            **kwargs)
-
-  def assign(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
-      return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
-          var, *args, **kwargs)
-
-  def _is_mirrored(self):
-    return False
-
-  def scatter_sub(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_add(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_max(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_min(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_mul(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_div(self, *args, **kwargs):
-    raise NotImplementedError
-
-  def scatter_update(self, *args, **kwargs):
-    raise NotImplementedError
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 7dedbee..50cd8d7 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -700,49 +700,49 @@
 
   def scatter_sub(self, sparse_delta, use_locking=False, name=None):
     if self._policy:
-      return self._policy.scatter_sub(
+      self._policy.scatter_sub(
           self, sparse_delta, use_locking=use_locking, name=name)
     return values_util.scatter_sub(
         self, sparse_delta, use_locking=use_locking, name=name)
 
   def scatter_add(self, sparse_delta, use_locking=False, name=None):
     if self._policy:
-      return self._policy.scatter_add(
+      self._policy.scatter_add(
           self, sparse_delta, use_locking=use_locking, name=name)
     return values_util.scatter_add(
         self, sparse_delta, use_locking=use_locking, name=name)
 
   def scatter_mul(self, sparse_delta, use_locking=False, name=None):
     if self._policy:
-      return self._policy.scatter_mul(
+      self._policy.scatter_mul(
           self, sparse_delta, use_locking=use_locking, name=name)
     return values_util.scatter_mul(
         self, sparse_delta, use_locking=use_locking, name=name)
 
   def scatter_div(self, sparse_delta, use_locking=False, name=None):
     if self._policy:
-      return self._policy.scatter_div(
+      self._policy.scatter_div(
           self, sparse_delta, use_locking=use_locking, name=name)
     return values_util.scatter_div(
         self, sparse_delta, use_locking=use_locking, name=name)
 
   def scatter_min(self, sparse_delta, use_locking=False, name=None):
     if self._policy:
-      return self._policy.scatter_min(
+      self._policy.scatter_min(
           self, sparse_delta, use_locking=use_locking, name=name)
     return values_util.scatter_min(
         self, sparse_delta, use_locking=use_locking, name=name)
 
   def scatter_max(self, sparse_delta, use_locking=False, name=None):
     if self._policy:
-      return self._policy.scatter_max(
+      self._policy.scatter_max(
           self, sparse_delta, use_locking=use_locking, name=name)
     return values_util.scatter_max(
         self, sparse_delta, use_locking=use_locking, name=name)
 
   def scatter_update(self, sparse_delta, use_locking=False, name=None):
     if self._policy:
-      return self._policy.scatter_update(
+      self._policy.scatter_update(
           self, sparse_delta, use_locking=use_locking, name=name)
     return values_util.scatter_update(
         self, sparse_delta, use_locking=use_locking, name=name)
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index e445c11..1c09073 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 import copy
+import itertools
 import os
 
 from absl.testing import parameterized
@@ -29,12 +30,14 @@
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import packed_distributed_variable as packed
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import test_util as ds_test_util
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values as values_lib
+from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
@@ -48,56 +51,19 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.saved_model import save_context
 from tensorflow.python.saved_model import save_options
+from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.types import core
 from tensorflow.python.util import nest
 
 
-def _device_str(d):
-  return "/device:GPU:" + str(d)
-
-
-def _nested_value(d):
-  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
-
-
-def _make_mirrored_val(init_val=5.0):
-  v = []
-  devices = ["/device:GPU:0", "/device:CPU:0"]
-  for d, _ in zip(devices, ["v", "v/replica"]):
-    with ops.device(d):
-      v.append(constant_op.constant(init_val))
-  return values_lib.Mirrored(v)
-
-
-def _make_mirrored():
-  v = []
-  devices = ["/device:GPU:0", "/device:CPU:0"]
-  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
-    with ops.device(d):
-      v.append(variable_scope.get_variable(
-          name=n, initializer=init, use_resource=True))
-  mirrored = values_lib.MirroredVariable(
-      None, v, variable_scope.VariableAggregation.SUM)
-  return mirrored
-
-
-def mirrored_and_tpu_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
-          strategy_combinations.tpu_strategy_packed_var,
-      ],
-      mode=["graph", "eager"])
-
-
 class DistributedValuesTest(test.TestCase, parameterized.TestCase):
 
   def testGetEager(self):
@@ -397,6 +363,45 @@
     self.assertEqual(v.x, v_deep_copy.x)
 
 
+def _device_str(d):
+  return "/device:GPU:" + str(d)
+
+
+def _nested_value(d):
+  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
+
+
+def _make_mirrored_val(init_val=5.0):
+  v = []
+  devices = ["/device:GPU:0", "/device:CPU:0"]
+  for d, _ in zip(devices, ["v", "v/replica"]):
+    with ops.device(d):
+      v.append(constant_op.constant(init_val))
+  return values_lib.Mirrored(v)
+
+
+def _make_mirrored():
+  v = []
+  devices = ["/device:GPU:0", "/device:CPU:0"]
+  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
+    with ops.device(d):
+      v.append(variable_scope.get_variable(
+          name=n, initializer=init, use_resource=True))
+  mirrored = values_lib.MirroredVariable(
+      None, v, variable_scope.VariableAggregation.SUM)
+  return mirrored
+
+
+def mirrored_and_tpu_strategy_combinations():
+  return combinations.combine(
+      distribution=[
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          strategy_combinations.tpu_strategy,
+          strategy_combinations.tpu_strategy_packed_var,
+      ],
+      mode=["graph", "eager"])
+
+
 @combinations.generate(
     combinations.combine(
         distribution=[
@@ -791,6 +796,507 @@
     save_path = self._save_normal()
     self._restore_mirrored(save_path)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_one_gpu,
+          ],
+          mode=["graph"]))
+  def testFetchAMirroredVariable(self, distribution):
+    with self.session(graph=ops.Graph()) as sess, distribution.scope():
+      with ops.device("/device:GPU:0"):
+        v = variable_scope.get_variable(
+            name="v", initializer=1., use_resource=True)
+      mirrored = values_lib.MirroredVariable(
+          distribution, (v,), variable_scope.VariableAggregation.MEAN)
+      sess.run(variables_lib.global_variables_initializer())
+      sess.run({"complicated": mirrored})
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+          ],
+          mode=["eager"]))
+  def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
+    with distribution.scope():
+      v = variables_lib.Variable(1.0, name="foo")
+
+    @def_function.function
+    def mytest():
+      def model_fn():
+        v.assign(5.0)
+        return v.read_value()
+
+      return distribution.run(model_fn)
+
+    mytest()
+    self.assertAllEqual([5.0, 5.0], self.evaluate(v.values))
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+          ],
+          mode=["graph", "eager"]))
+  def testValueInReplicaContext(self, distribution):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          1., aggregation=variables_lib.VariableAggregation.MEAN)
+      self.evaluate(variables_lib.global_variables_initializer())
+
+      @def_function.function
+      def f():
+        with ops.control_dependencies([v.assign_add(1.)]):
+          return v.value()
+
+      results = self.evaluate(
+          distribution.experimental_local_results(
+              distribution.run(f)))
+      for value in results:
+        self.assertEqual(2., value)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+          ],
+          mode=["graph", "eager"]))
+  def testAssignOutOfScope(self, distribution):
+    with distribution.scope():
+      mirrored = variables_lib.Variable(1.)
+    self.evaluate(mirrored.assign(3.))
+    self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
+    for component in mirrored.values:
+      self.assertEqual(self.evaluate(component.read_value()), 3.)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testAssignAggregationMeanDTypeNonFloat(self, distribution):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          1,
+          aggregation=variable_scope.VariableAggregation.MEAN,
+          dtype=dtypes.int32)
+    self.evaluate(v.initializer)
+
+    @def_function.function
+    def assign():
+      ctx = distribution_strategy_context.get_replica_context()
+      return v.assign(ctx.replica_id_in_sync_group)
+
+    # disallow assign() with distributed value in replica context.
+    with self.assertRaisesRegex(ValueError,
+                                "Cannot update non-float variables"):
+      self.evaluate(
+          distribution.experimental_local_results(
+              distribution.run(assign)))
+
+    # allow assign() with same value in replica context.
+    @def_function.function
+    def assign_same():
+      return v.assign(2)
+
+    self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(assign_same)))
+    self.assertEqual(self.evaluate(v.read_value()), 2)
+
+    # allow assign() with mirrored variable in replica context.
+    with distribution.scope():
+      v2 = variables_lib.Variable(
+          3,
+          aggregation=variable_scope.VariableAggregation.SUM,
+          dtype=dtypes.int32)
+    self.evaluate(v2.initializer)
+
+    @def_function.function
+    def assign_mirrored():
+      return v.assign(v2)
+
+    self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(assign_mirrored)))
+    self.assertEqual(self.evaluate(v.read_value()), 3)
+
+    # allow assign() in cross replica context.
+    with distribution.scope():
+      self.evaluate(v.assign(4))
+      self.assertEqual(self.evaluate(v.read_value()), 4)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+          ],
+          mode=["eager"]))
+  def testInitializedToSameValueInsideEagerRun(self, distribution):
+    v = [None]
+
+    @def_function.function
+    def step():
+
+      def f():
+        if v[0] is None:
+          v[0] = variables_lib.Variable(random_ops.random_normal([]))
+
+      distribution.run(f)
+
+    context.set_global_seed(None)
+    step()
+    vals = self.evaluate(v[0].values)
+    self.assertAllEqual(vals[0], vals[1])
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+          ],
+          mode=["graph", "eager"]))
+  def testAggregationOnlyFirstReplica(self, distribution):
+    with distribution.scope():
+      v = variable_scope.variable(
+          15.,
+          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
+          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+    self.evaluate(variables_lib.global_variables_initializer())
+
+    @def_function.function
+    def assign():
+      ctx = distribution_strategy_context.get_replica_context()
+      replica_id = ctx.replica_id_in_sync_group
+      return v.assign(math_ops.cast(replica_id, dtypes.float32))
+    per_replica_results = self.evaluate(distribution.experimental_local_results(
+        distribution.run(assign)))
+    # The per-replica values should always match the first replicas value.
+    self.assertAllEqual(
+        array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
+        per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+          ],
+          mode=["eager"]))
+  def testInitScope(self, distribution):
+
+    class C(object):
+      pass
+
+    obj = C()
+    obj.w = None
+    obj.v = None
+
+    @def_function.function
+    def assign():
+      with ops.init_scope():
+        if obj.w is None:
+          obj.w = variables_lib.Variable(
+              0, aggregation=variables_lib.VariableAggregation.MEAN)
+          obj.v = variables_lib.Variable(
+              obj.w.read_value(),
+              aggregation=variables_lib.VariableAggregation.MEAN)
+
+      return obj.v.assign_add(2)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(distribution.run(assign)))
+    self.assertAllEqual([2, 2], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+          ],
+          mode=["eager"]))
+  def testOperatorOverride(self, distribution):
+
+    with distribution.scope():
+      v = variable_scope.variable(
+          1, aggregation=variables_lib.VariableAggregation.MEAN)
+
+    self.assertEqual(2, self.evaluate(v + 1))
+
+    @def_function.function
+    def add():
+      return v + 1
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(distribution.run(add)))
+    self.assertAllEqual([2, 2], per_replica_results)
+
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testAssignAdd(self, distribution):
+    with distribution.scope():
+      v = variable_scope.variable(
+          1, aggregation=variables_lib.VariableAggregation.MEAN)
+    self.evaluate(variables_lib.global_variables_initializer())
+
+    @def_function.function
+    def assign():
+      return v.assign_add(2)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(distribution.run(assign)))
+    # The per-replica values should always match the first replicas value.
+    self.assertAllEqual([3, 3], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterSub(self, distribution):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
+    self.evaluate(v.initializer)
+
+    @def_function.function
+    def scatter_sub():
+      ctx = distribution_strategy_context.get_replica_context()
+      replica_id = ctx.replica_id_in_sync_group
+      value = indexed_slices.IndexedSlices(
+          values=array_ops.stack([
+              math_ops.cast(replica_id, dtypes.float32),
+              math_ops.cast(replica_id + 1, dtypes.float32)
+          ]),
+          indices=array_ops.stack([replica_id, replica_id + 1]),
+          dense_shape=(3,))
+      return v.scatter_sub(value)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(scatter_sub)))
+    self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterAdd(self, distribution):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
+    self.evaluate(v.initializer)
+
+    @def_function.function
+    def scatter_add():
+      ctx = distribution_strategy_context.get_replica_context()
+      replica_id = ctx.replica_id_in_sync_group
+      value = indexed_slices.IndexedSlices(
+          values=array_ops.stack([replica_id, replica_id + 1]),
+          indices=array_ops.stack([replica_id, replica_id + 1]),
+          dense_shape=(3,))
+      return v.scatter_add(value)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(scatter_add)))
+    self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterDiv(self, distribution):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
+    self.evaluate(v.initializer)
+
+    @def_function.function
+    def scatter_div():
+      ctx = distribution_strategy_context.get_replica_context()
+      replica_id = ctx.replica_id_in_sync_group
+      value = indexed_slices.IndexedSlices(
+          values=array_ops.reshape(replica_id + 2, [1]),
+          indices=array_ops.reshape(replica_id, [1]),
+          dense_shape=(3,))
+      return v.scatter_div(value)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(scatter_div)))
+    self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterMul(self, distribution):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
+    self.evaluate(v.initializer)
+
+    @def_function.function
+    def scatter_mul():
+      ctx = distribution_strategy_context.get_replica_context()
+      replica_id = ctx.replica_id_in_sync_group
+      value = indexed_slices.IndexedSlices(
+          values=array_ops.reshape(
+              math_ops.cast(replica_id + 2, dtypes.float32), [1]),
+          indices=array_ops.reshape(replica_id, [1]),
+          dense_shape=(3,))
+      return v.scatter_mul(value)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(scatter_mul)))
+    self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterMin(self, distribution):
+    with distribution.scope():
+      v1 = variables_lib.Variable(
+          [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
+      v2 = variables_lib.Variable(
+          [0, 2, 0],
+          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+    self.evaluate(variables_lib.global_variables_initializer())
+
+    @def_function.function
+    def scatter_min(v):
+      value = indexed_slices.IndexedSlices(
+          values=array_ops.identity([1]),
+          indices=array_ops.identity([1]),
+          dense_shape=(3,))
+      return v.scatter_min(value)
+
+    with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
+      self.evaluate(
+          distribution.experimental_local_results(
+              distribution.run(scatter_min, args=(v1,))))
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(scatter_min, args=(v2,))))
+    self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterMax(self, distribution):
+    with distribution.scope():
+      v1 = variables_lib.Variable(
+          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
+      v2 = variables_lib.Variable(
+          [0, 0, 0],
+          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+    self.evaluate(variables_lib.global_variables_initializer())
+
+    @def_function.function
+    def scatter_max(v):
+      value = indexed_slices.IndexedSlices(
+          values=array_ops.identity([1]),
+          indices=array_ops.identity([0]),
+          dense_shape=(3,))
+      return v.scatter_max(value)
+
+    with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
+      self.evaluate(
+          distribution.experimental_local_results(
+              distribution.run(scatter_max, args=(v1,))))
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(scatter_max, args=(v2,))))
+    self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterUpdate(self, distribution):
+    with distribution.scope():
+      v1 = variables_lib.Variable(
+          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
+      v2 = variables_lib.Variable(
+          [0, 0, 0],
+          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+    self.evaluate(variables_lib.global_variables_initializer())
+
+    @def_function.function
+    def scatter_update(v):
+      value = indexed_slices.IndexedSlices(
+          values=array_ops.identity([3]),
+          indices=array_ops.identity([1]),
+          dense_shape=(3,))
+      return v.scatter_update(value)
+
+    with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
+      self.evaluate(
+          distribution.experimental_local_results(
+              distribution.run(scatter_update, args=(v1,))))
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(
+            distribution.run(scatter_update, args=(v2,))))
+    self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          mode=["graph", "eager"]))
+  def testScatterOpsInCrossReplicaContext(self, distribution):
+    with distribution.scope():
+      v1 = variables_lib.Variable(
+          [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
+      v2 = variables_lib.Variable([1, 1, 1])
+    self.evaluate(variables_lib.global_variables_initializer())
+
+    value = indexed_slices.IndexedSlices(
+        values=array_ops.identity([2]),
+        indices=array_ops.identity([0]),
+        dense_shape=(3,))
+    with distribution.scope():
+      self.evaluate(v1.scatter_add(value))
+      self.assertAllEqual([3, 1, 1], self.evaluate(v1.read_value()))
+
+      self.evaluate(v2.scatter_min(value))
+      self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
+
 
 _TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
 
@@ -815,6 +1321,38 @@
   return v, replica_local
 
 
+class SyncOnReadVariablePropertiesTest(test.TestCase):
+
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testProperties(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+    v, replica_local = _make_replica_local(
+        variable_scope.VariableAggregation.SUM)
+
+    self.assertEqual(v[0].constraint, replica_local.constraint)
+    self.assertEqual(v[0].name, replica_local.name)
+    self.assertEqual(v[0].dtype, replica_local.dtype)
+    self.assertEqual(v[0].shape, replica_local.shape)
+    self.assertEqual(variable_scope.VariableAggregation.SUM,
+                     replica_local.aggregation)
+
+  @test_util.run_v2_only
+  def testCanPassToDefFun(self):
+    @def_function.function
+    def add1(x):
+      return x + 1
+
+    v = variable_scope.get_variable(
+        name="v", initializer=[1.], use_resource=True)
+    replica_local = values_lib.SyncOnReadVariable(
+        None, (v,), variable_scope.VariableAggregation.MEAN)
+    self.assertEqual(2., self.evaluate(add1(replica_local)))
+
+
 # TODO(b/144432582): Add variable aggregation type to combinations to simplify
 # tests.
 def strategy_and_run_tf_function_combinations():
@@ -851,35 +1389,6 @@
     save_path, _ = self._save_return_saver(sess, var)
     return save_path
 
-  config = config_pb2.ConfigProto()
-  config.allow_soft_placement = True
-
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testProperties(self):
-    if context.num_gpus() < 1 and context.executing_eagerly():
-      self.skipTest("A GPU is not available for this test in eager mode.")
-    v, replica_local = _make_replica_local(
-        variable_scope.VariableAggregation.SUM)
-
-    self.assertEqual(v[0].constraint, replica_local.constraint)
-    self.assertEqual(v[0].name, replica_local.name)
-    self.assertEqual(v[0].dtype, replica_local.dtype)
-    self.assertEqual(v[0].shape, replica_local.shape)
-    self.assertEqual(variable_scope.VariableAggregation.SUM,
-                     replica_local.aggregation)
-
-  @test_util.run_v2_only
-  def testCanPassToDefFun(self):
-    @def_function.function
-    def add1(x):
-      return x + 1
-
-    v = variable_scope.get_variable(
-        name="v", initializer=[1.], use_resource=True)
-    replica_local = values_lib.SyncOnReadVariable(
-        None, (v,), variable_scope.VariableAggregation.MEAN)
-    self.assertEqual(2., self.evaluate(add1(replica_local)))
-
   @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testTensorConversion(self, distribution):
     with context.graph_mode():
@@ -1076,6 +1585,453 @@
     save_path = self._save_normal()
     self._restore_replica_local_sum(save_path, distribution)
 
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testAssign(self, distribution, experimental_run_tf_function):
+
+    def assign(fn, v, update_value, cross_replica):
+      update_fn = lambda: getattr(v, fn)(update_value)
+      if cross_replica:
+        return update_fn()
+      else:
+        if experimental_run_tf_function:
+          update_fn = def_function.function(update_fn)
+        return distribution.experimental_local_results(
+            distribution.run(update_fn))
+
+    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
+    aggregations = [
+        variables_lib.VariableAggregation.NONE,
+        variables_lib.VariableAggregation.SUM,
+        variables_lib.VariableAggregation.MEAN,
+        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+    ]
+    options = list(
+        x for x in itertools.product(updates, aggregations, [True, False]))
+    for update, aggregation, cross_replica in options:
+      # VariableAggregation.SUM in cross-replica mode is tested below,
+      # VariableAggregation.NONE in cross-replica mode is not supported.
+      if cross_replica and aggregation in [
+          variables_lib.VariableAggregation.SUM,
+          variables_lib.VariableAggregation.NONE,
+      ]:
+        continue
+      with distribution.scope():
+        v = variable_scope.variable(
+            0.,
+            synchronization=variables_lib.VariableSynchronization.ON_READ,
+            aggregation=aggregation)
+      self.evaluate(variables_lib.global_variables_initializer())
+      fn, update_value = update
+      self.evaluate(assign(fn, v, update_value, cross_replica))
+      for component in v._values:
+        self.assertAllEqual(self.evaluate(component.read_value()),
+                            self.evaluate(array_ops.ones_like(component)))
+
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testAssignDtypeConversion(self, distribution,
+                                experimental_run_tf_function):
+
+    def assign(fn, v, update_value, cross_replica):
+      update_fn = lambda: getattr(v, fn)(update_value)
+      if cross_replica:
+        return update_fn()
+      else:
+        if experimental_run_tf_function:
+          update_fn = def_function.function(update_fn)
+        return distribution.experimental_local_results(
+            distribution.run(update_fn))
+
+    updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
+    aggregations = [
+        variables_lib.VariableAggregation.NONE,
+        variables_lib.VariableAggregation.SUM,
+        variables_lib.VariableAggregation.MEAN,
+        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+    ]
+    options = list(
+        x for x in itertools.product(updates, aggregations, [True, False]))
+    for update, aggregation, cross_replica in options:
+      # VariableAggregation.SUM in cross-replica mode is tested below,
+      # VariableAggregation.NONE in cross-replica mode is not supported.
+      if cross_replica and aggregation in [
+          variables_lib.VariableAggregation.SUM,
+          variables_lib.VariableAggregation.NONE,
+      ]:
+        continue
+      with distribution.scope():
+        v = variable_scope.variable(
+            0.,
+            synchronization=variables_lib.VariableSynchronization.ON_READ,
+            aggregation=aggregation)
+      self.evaluate(variables_lib.global_variables_initializer())
+      fn, update_value = update
+      self.evaluate(assign(fn, v, update_value, cross_replica))
+      for component in v._values:
+        self.assertAllEqual(self.evaluate(component.read_value()),
+                            self.evaluate(array_ops.ones_like(component)))
+
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testAssignWithAggregationSum(self, distribution):
+    with distribution.scope():
+      v = variable_scope.variable(
+          0.,
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=variables_lib.VariableAggregation.SUM)
+    self.evaluate(variables_lib.global_variables_initializer())
+    self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
+    for component in v._values:
+      self.assertAllEqual(self.evaluate(component.read_value()),
+                          self.evaluate(array_ops.ones_like(component)))
+
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testAssignAddSubWithAggregationSum(self, distribution):
+    with distribution.scope():
+      v = variable_scope.variable(
+          0.,
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=variables_lib.VariableAggregation.SUM)
+    self.evaluate(variables_lib.global_variables_initializer())
+    with self.assertRaisesRegex(
+        ValueError, "SyncOnReadVariable does not support "):
+      self.evaluate(v.assign_add(1.))
+    with self.assertRaisesRegex(
+        ValueError, "SyncOnReadVariable does not support "):
+      self.evaluate(v.assign_sub(1.))
+
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testReadValueInReplicaContext(self, distribution,
+                                    experimental_run_tf_function):
+    aggregations = [
+        variables_lib.VariableAggregation.NONE,
+        variables_lib.VariableAggregation.SUM,
+        variables_lib.VariableAggregation.MEAN,
+        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+    ]
+    for aggregation in aggregations:
+      with distribution.scope():
+        v = variable_scope.variable(
+            0.,
+            synchronization=variables_lib.VariableSynchronization.ON_READ,
+            aggregation=aggregation)
+      self.evaluate(variables_lib.global_variables_initializer())
+      if experimental_run_tf_function:
+        read_var_fn = def_function.function(v.read_value)
+      else:
+        read_var_fn = v.read_value
+      results = self.evaluate(
+          distribution.experimental_local_results(
+              distribution.run(read_var_fn)))
+      for component, value in zip(v._values, results):
+        self.assertAllEqual(self.evaluate(component.read_value()), value)
+
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testReadValueInCrossReplicaContext(self, distribution,
+                                         experimental_run_tf_function):
+    aggregations = [
+        variables_lib.VariableAggregation.SUM,
+        variables_lib.VariableAggregation.MEAN,
+        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+    ]
+    for aggregation in aggregations:
+      if isinstance(distribution, _TPU_STRATEGIES):
+        resolver = tpu_cluster_resolver.TPUClusterResolver("")
+        tpu_strategy_util.initialize_tpu_system(resolver)
+      with distribution.scope():
+        v = variable_scope.variable(
+            0.,
+            synchronization=variables_lib.VariableSynchronization.ON_READ,
+            aggregation=aggregation)
+      self.evaluate(variables_lib.global_variables_initializer())
+
+      def assign(v=v):
+        ctx = distribution_strategy_context.get_replica_context()
+        replica_id = ctx.replica_id_in_sync_group
+        return v.assign(math_ops.cast(replica_id, dtypes.float32))
+
+      if experimental_run_tf_function:
+        assign = def_function.function(assign)
+
+      self.evaluate(
+          distribution.experimental_local_results(distribution.run(assign)))
+      num_replicas = distribution.num_replicas_in_sync
+      sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
+      if aggregation == variables_lib.VariableAggregation.SUM:
+        expected = sum_of_replica_values
+      elif aggregation == variables_lib.VariableAggregation.MEAN:
+        expected = sum_of_replica_values / num_replicas
+      else:
+        expected = 0
+      self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
+      self.assertEqual(expected, self.evaluate(v.value()), aggregation)
+      self.assertEqual(expected, self.evaluate(v), aggregation)
+      self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
+                       aggregation)
+
+  # TODO(b/145574622): Re-enable this test once ReduceOp argument is
+  # respected on GPUs.
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def disable_testAllReduce(self, distribution,
+                            experimental_run_tf_function):
+    with distribution.scope():
+      v = variable_scope.variable(
+          2.,
+          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
+          aggregation=variables_lib.VariableAggregation.MEAN)
+    self.evaluate(variables_lib.global_variables_initializer())
+
+    def all_reduce():
+      ctx = distribution_strategy_context.get_replica_context()
+      replica_id = ctx.replica_id_in_sync_group
+      return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
+                                                      dtypes.float32)
+
+    if experimental_run_tf_function:
+      all_reduce = def_function.function(all_reduce)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(distribution.run(all_reduce)))
+    expected_result = []
+    for i in range(distribution.num_replicas_in_sync):
+      expected_result.append(2.0 * distribution.num_replicas_in_sync +
+                             1.0 * i)
+    self.assertEqual(per_replica_results, tuple(expected_result))
+
+  @combinations.generate(strategy_and_run_tf_function_combinations())
+  def testAssignPerReplicaBeforeRead(self, distribution,
+                                     experimental_run_tf_function):
+    aggregations = [
+        variables_lib.VariableAggregation.SUM,
+        variables_lib.VariableAggregation.MEAN,
+        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+    ]
+    for aggregation in aggregations:
+      with distribution.scope():
+        v = variable_scope.variable(
+            0.,
+            synchronization=variables_lib.VariableSynchronization.ON_READ,
+            aggregation=aggregation)
+      self.evaluate(variables_lib.global_variables_initializer())
+
+      def assign(var=v):
+        ctx = distribution_strategy_context.get_replica_context()
+        replica_id = ctx.replica_id_in_sync_group
+        return var.assign(math_ops.cast(replica_id, dtypes.float32))
+
+      if experimental_run_tf_function:
+        assign = def_function.function(assign)
+
+      per_replica_results = self.evaluate(
+          distribution.experimental_local_results(distribution.run(assign)))
+      expected_result = []
+      for i in range(distribution.num_replicas_in_sync):
+        expected_result.append(1.0 * i)
+      self.assertEqual(per_replica_results, tuple(expected_result))
+
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
+    with distribution.scope():
+      v = variable_scope.variable(
+          0.,
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=variables_lib.VariableAggregation.NONE)
+    self.evaluate(variables_lib.global_variables_initializer())
+    with self.assertRaisesRegex(
+        ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
+      self.evaluate(v.read_value())
+
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testInitializedToSameValueInsideEagerRun(self, distribution):
+    if not context.executing_eagerly(): self.skipTest("eager only")
+
+    v = [None]
+    @def_function.function
+    def step():
+      def f():
+        if v[0] is None:
+          v[0] = variables_lib.Variable(
+              random_ops.random_normal([]),
+              synchronization=variables_lib.VariableSynchronization.ON_READ)
+
+      distribution.run(f)
+
+    context.set_global_seed(None)
+    step()
+    vals = self.evaluate(v[0].values)
+    self.assertAllEqual(vals[0], vals[1])
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.tpu_strategy,
+          ],
+          mode=["eager"]))
+  def testOperatorOverride(self, distribution):
+
+    with distribution.scope():
+      v = variable_scope.variable(
+          0.0,
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=variables_lib.VariableAggregation.MEAN)
+
+    @def_function.function
+    def assign():
+      ctx = distribution_strategy_context.get_replica_context()
+      replica_id = ctx.replica_id_in_sync_group
+      return v.assign(math_ops.cast(replica_id, dtypes.float32))
+
+    # Assign different replicas with different values.
+    distribution.run(assign)
+
+    self.assertEqual(1.5, self.evaluate(v + 1))
+
+    @def_function.function
+    def add():
+      return v + 1
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(distribution.run(add)))
+    self.assertAllEqual([1, 2], per_replica_results)
+
+
+@combinations.generate(
+    combinations.combine(
+        distribution=[
+            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+        ],
+        aggregation=[
+            variables_lib.VariableAggregation.MEAN,
+            variables_lib.VariableAggregation.SUM,
+            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+        ],
+        mode=["graph", "eager"]))
+class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
+
+  def testScatterSub(self, distribution, aggregation):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [1., 1., 1.],
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=aggregation)
+    self.evaluate(v.initializer)
+
+    delta = values_lib.PerReplica([
+        indexed_slices.IndexedSlices(
+            values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
+        indexed_slices.IndexedSlices(
+            values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
+    ])
+
+    with self.assertRaises(NotImplementedError):
+      self.evaluate(distribution.run(v.scatter_sub, args=(delta,)))
+
+  def testScatterAdd(self, distribution, aggregation):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [1., 1., 1.],
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=aggregation)
+    self.evaluate(v.initializer)
+
+    delta = values_lib.PerReplica([
+        indexed_slices.IndexedSlices(
+            values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
+        indexed_slices.IndexedSlices(
+            values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
+    ])
+
+    with self.assertRaises(NotImplementedError):
+      self.evaluate(distribution.run(v.scatter_add, args=(delta,)))
+
+  def testScatterDiv(self, distribution, aggregation):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [2., 6., 1.],
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=aggregation)
+    self.evaluate(v.initializer)
+
+    delta = values_lib.PerReplica([
+        indexed_slices.IndexedSlices(
+            values=[[2.], [2.]], indices=[0, 1], dense_shape=(3,)),
+        indexed_slices.IndexedSlices(
+            values=[[3.], [3.]], indices=[1, 2], dense_shape=(3,)),
+    ])
+
+    with self.assertRaises(NotImplementedError):
+      self.evaluate(distribution.run(v.scatter_div, args=(delta,)))
+
+  def testScatterMul(self, distribution, aggregation):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [2., 1., 1.],
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=aggregation)
+    self.evaluate(v.initializer)
+
+    delta = values_lib.PerReplica([
+        indexed_slices.IndexedSlices(
+            values=[[2.], [3.]], indices=[0, 1], dense_shape=(3,)),
+        indexed_slices.IndexedSlices(
+            values=[[4.], [5.]], indices=[1, 2], dense_shape=(3,)),
+    ])
+
+    with self.assertRaises(NotImplementedError):
+      self.evaluate(distribution.run(v.scatter_mul, args=(delta,)))
+
+  def testScatterMin(self, distribution, aggregation):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [3., 4., 5.],
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=aggregation)
+    self.evaluate(v.initializer)
+
+    delta = values_lib.PerReplica([
+        indexed_slices.IndexedSlices(
+            values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
+        indexed_slices.IndexedSlices(
+            values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
+    ])
+
+    with self.assertRaises(NotImplementedError):
+      self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
+
+  def testScatterMax(self, distribution, aggregation):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [3., 4., 5.],
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=aggregation)
+    self.evaluate(v.initializer)
+
+    delta = values_lib.PerReplica([
+        indexed_slices.IndexedSlices(
+            values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
+        indexed_slices.IndexedSlices(
+            values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
+    ])
+
+    with self.assertRaises(NotImplementedError):
+      self.evaluate(distribution.run(v.scatter_max, args=(delta,)))
+
+  def testScatterUpdate(self, distribution, aggregation):
+    with distribution.scope():
+      v = variables_lib.Variable(
+          [0., 0., 0.],
+          synchronization=variables_lib.VariableSynchronization.ON_READ,
+          aggregation=aggregation)
+    self.evaluate(v.initializer)
+
+    delta = values_lib.PerReplica([
+        indexed_slices.IndexedSlices(
+            values=[[1.], [2.]], indices=[0, 1], dense_shape=(3,)),
+        indexed_slices.IndexedSlices(
+            values=[[3.], [4.]], indices=[1, 2], dense_shape=(3,)),
+    ])
+
+    with self.assertRaises(NotImplementedError):
+      self.evaluate(distribution.run(v.scatter_update, args=(delta,)))
+
 
 class MirroredTest(test.TestCase):
 
diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py
deleted file mode 100644
index 5866c0c..0000000
--- a/tensorflow/python/distribute/vars_test.py
+++ /dev/null
@@ -1,1270 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the distributed values library."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-from absl.testing import parameterized
-
-from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.distribute import strategy_combinations
-from tensorflow.python.distribute import tpu_strategy
-from tensorflow.python.distribute import tpu_values
-from tensorflow.python.distribute import values
-from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
-from tensorflow.python.eager import context
-from tensorflow.python.eager import def_function
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import indexed_slices
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.tpu import tpu_strategy_util
-
-
-_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
-
-
-def strategy_and_run_tf_function_combinations():
-  # Test the combination of different strategies and whether a tf.function
-  # is passed into strategy.run."""
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-      ],
-      mode=["graph", "eager"],
-      experimental_run_tf_function=[True, False],
-      use_var_policy=[True, False]) + combinations.combine(
-          distribution=[
-              strategy_combinations.tpu_strategy,
-              strategy_combinations.tpu_strategy_packed_var,
-          ],
-          mode=["graph", "eager"],
-          experimental_run_tf_function=[True],
-          use_var_policy=[True, False])
-
-
-def strategy_with_var_policy():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
-          strategy_combinations.tpu_strategy_packed_var,
-          strategy_combinations.central_storage_strategy_with_two_gpus,
-      ],
-      mode=["graph", "eager"],
-      use_var_policy=[True, False])
-
-
-class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
-
-  @combinations.generate(
-      combinations.combine(
-          distribution=[
-              strategy_combinations.mirrored_strategy_with_one_gpu,
-          ],
-          mode=["graph"]))
-  def testFetchAMirroredVariable(self, distribution):
-    with self.session(graph=ops.Graph()) as sess, distribution.scope():
-      with ops.device("/device:GPU:0"):
-        v = variable_scope.get_variable(
-            name="v", initializer=1., use_resource=True)
-      mirrored = values.MirroredVariable(
-          distribution, (v,), variable_scope.VariableAggregation.MEAN)
-      sess.run(variables_lib.global_variables_initializer())
-      sess.run({"complicated": mirrored})
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssign(self, distribution, experimental_run_tf_function,
-                 use_var_policy):
-
-    def assign(fn, v, update_value, cross_replica):
-      update_fn = lambda: getattr(v, fn)(update_value)
-      if cross_replica:
-        return update_fn()
-      else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
-        return distribution.experimental_local_results(
-            distribution.run(update_fn))
-
-    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
-    aggregations = [
-        variables_lib.VariableAggregation.NONE,
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    options = list(
-        x for x in itertools.product(updates, aggregations, [True, False]))
-    for update, aggregation, cross_replica in options:
-      # assign in replica context with SUM does not make sense cause you can
-      # just do value * num replicas error is 1. is not a distributed value and
-      # is unsupported for aggregation SUM
-      if (not cross_replica and aggregation ==
-          variables_lib.VariableAggregation.SUM):
-        continue
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      fn, update_value = update
-      self.evaluate(assign(fn, v, update_value, cross_replica))
-      for component in v._values:
-        self.assertAllEqual(self.evaluate(component.read_value()),
-                            self.evaluate(array_ops.ones_like(component)))
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignOnWriteVar(self, distribution, experimental_run_tf_function,
-                           use_var_policy):
-
-    with distribution.scope():
-      v_to_assign = variable_scope.variable(
-          2., aggregation=variables_lib.VariableAggregation.MEAN)
-      v_to_assign_sub = variable_scope.variable(
-          -2., aggregation=variables_lib.VariableAggregation.MEAN)
-
-    def assign(fn, v, update_value, cross_replica):
-      update_fn = lambda: getattr(v, fn)(update_value)
-      if cross_replica:
-        return update_fn()
-      else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
-        return distribution.experimental_local_results(
-            distribution.run(update_fn))
-
-    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
-               ("assign_sub", v_to_assign_sub)]
-    aggregations = [
-        variables_lib.VariableAggregation.NONE,
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    options = list(
-        x for x in itertools.product(updates, aggregations, [True, False]))
-    for update, aggregation, cross_replica in options:
-      # assign in replica context with SUM does not make sense cause you can
-      # just do value * num replicas error is 1. is not a distributed value and
-      # is unsupported for aggregation SUM
-      if aggregation == variables_lib.VariableAggregation.SUM:
-        continue
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      fn, update_value = update
-      self.evaluate(assign(fn, v, update_value, cross_replica))
-      for component in v._values:
-        self.assertAllEqual(2.0, self.evaluate(component.read_value()))
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
-                              use_var_policy):
-
-    if isinstance(distribution, _TPU_STRATEGIES):
-      self.skipTest("Assigning PerReplica values is not supported. See"
-                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")
-
-    with distribution.scope():
-      per_replica_value = values.PerReplica(
-          [constant_op.constant(2.0),
-           constant_op.constant(2.0)])
-      per_replica_sub_value = values.PerReplica(
-          [constant_op.constant(-2.0),
-           constant_op.constant(-2.0)])
-
-    def assign(fn, v, update_value, cross_replica):
-      update_fn = lambda: getattr(v, fn)(update_value)
-      if cross_replica:
-        return update_fn()
-      else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
-        return distribution.experimental_local_results(
-            distribution.run(update_fn))
-
-    updates = [("assign", per_replica_value), ("assign_add", per_replica_value),
-               ("assign_sub", per_replica_sub_value)]
-    # We don't support assigning PerReplica valus to vars in replica context
-    # with aggregation=NONE.
-    aggregations = [
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    options = list(
-        x for x in itertools.product(updates, aggregations, [True, False]))
-    for update, aggregation, cross_replica in options:
-      # assign in replica context with SUM does not make sense cause you can
-      # just do value * num replicas error is 1. is not a distributed value and
-      # is unsupported for aggregation SUM
-      if cross_replica:
-        # We don't support assigning PerReplica values to MirroredVariables in
-        # cross replica context
-        continue
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      fn, update_value = update
-      self.evaluate(assign(fn, v, update_value, cross_replica))
-      if aggregation == variables_lib.VariableAggregation.SUM:
-        expected = 4.0
-      else:
-        expected = 2.0
-      for component in v._values:
-        self.assertAllEqual(expected, self.evaluate(component.read_value()))
-
-  @combinations.generate(strategy_with_var_policy())
-  def testValueInReplicaContext(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          1., aggregation=variables_lib.VariableAggregation.MEAN)
-      self.evaluate(variables_lib.global_variables_initializer())
-
-      @def_function.function
-      def f():
-        with ops.control_dependencies([v.assign_add(1.)]):
-          return v.value()
-
-      results = self.evaluate(
-          distribution.experimental_local_results(
-              distribution.run(f)))
-      for value in results:
-        self.assertEqual(2., value)
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInReplicaContext(self, distribution,
-                                    experimental_run_tf_function,
-                                    use_var_policy):
-    aggregations = [
-        variables_lib.VariableAggregation.NONE,
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    for aggregation in aggregations:
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      if experimental_run_tf_function:
-        read_var_fn = def_function.function(v.read_value)
-      else:
-        read_var_fn = v.read_value
-      results = self.evaluate(
-          distribution.experimental_local_results(
-              distribution.run(read_var_fn)))
-      for component, value in zip(v._values, results):
-        self.assertAllEqual(self.evaluate(component.read_value()), value)
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInCrossReplicaContext(self, distribution,
-                                         experimental_run_tf_function,
-                                         use_var_policy):
-    aggregations = [
-        variables_lib.VariableAggregation.NONE,
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    for aggregation in aggregations:
-      with distribution.scope():
-        v = variable_scope.variable(
-            2.,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-
-      if experimental_run_tf_function:
-        read_var_fn = def_function.function(v.read_value)
-      else:
-        read_var_fn = v.read_value
-
-      results = read_var_fn()
-      for component in v._values:
-        self.assertEqual(self.evaluate(component.read_value()),
-                         self.evaluate(results))
-
-  @combinations.generate(strategy_with_var_policy())
-  def testAssignOutOfScope(self, distribution, use_var_policy):
-    with distribution.scope():
-      mirrored = variables_lib.Variable(1.)
-    self.evaluate(mirrored.assign(3.))
-    self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
-    for component in mirrored.values:
-      self.assertEqual(self.evaluate(component.read_value()), 3.)
-
-  @combinations.generate(strategy_with_var_policy())
-  def testAssignAggregationMeanDTypeNonFloat(self, distribution,
-                                             use_var_policy):
-    if isinstance(distribution, _TPU_STRATEGIES):
-      self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
-                    "reenabling test.")
-
-    with distribution.scope():
-      v = variables_lib.Variable(
-          1,
-          aggregation=variable_scope.VariableAggregation.MEAN,
-          dtype=dtypes.int32)
-    self.evaluate(v.initializer)
-
-    @def_function.function
-    def assign():
-      ctx = distribution_strategy_context.get_replica_context()
-      return v.assign(ctx.replica_id_in_sync_group)
-
-    # disallow assign() with distributed value in replica context.
-    with self.assertRaisesRegex(ValueError,
-                                "Cannot update non-float variables"):
-      self.evaluate(
-          distribution.experimental_local_results(
-              distribution.run(assign)))
-
-    # allow assign() with same value in replica context.
-    @def_function.function
-    def assign_same():
-      return v.assign(2)
-
-    self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(assign_same)))
-    self.assertEqual(self.evaluate(v.read_value()), 2)
-
-    # allow assign() with mirrored variable in replica context.
-    with distribution.scope():
-      v2 = variables_lib.Variable(
-          3,
-          aggregation=variable_scope.VariableAggregation.SUM,
-          dtype=dtypes.int32)
-    self.evaluate(v2.initializer)
-
-    @def_function.function
-    def assign_mirrored():
-      return v.assign(v2)
-
-    self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(assign_mirrored)))
-    self.assertEqual(self.evaluate(v.read_value()), 3)
-
-    # allow assign() in cross replica context.
-    with distribution.scope():
-      self.evaluate(v.assign(4))
-      self.assertEqual(self.evaluate(v.read_value()), 4)
-
-  @combinations.generate(strategy_with_var_policy())
-  def testInitializedToSameValueInsideEagerRun(self, distribution,
-                                               use_var_policy):
-    if not context.executing_eagerly(): self.skipTest("eager only test")
-    v = [None]
-
-    @def_function.function
-    def step():
-
-      def f():
-        if v[0] is None:
-          v[0] = variables_lib.Variable(random_ops.random_normal([]))
-
-      distribution.run(f)
-
-    context.set_global_seed(None)
-    step()
-    vals = self.evaluate(v[0].values)
-    self.assertAllEqual(vals[0], vals[1])
-
-  @combinations.generate(strategy_with_var_policy())
-  def testAggregationOnlyFirstReplica(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variable_scope.variable(
-          15.,
-          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
-          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
-    self.evaluate(variables_lib.global_variables_initializer())
-
-    @def_function.function
-    def assign():
-      ctx = distribution_strategy_context.get_replica_context()
-      replica_id = ctx.replica_id_in_sync_group
-      return v.assign(math_ops.cast(replica_id, dtypes.float32))
-    per_replica_results = self.evaluate(distribution.experimental_local_results(
-        distribution.run(assign)))
-    # The per-replica values should always match the first replicas value.
-    self.assertAllEqual(
-        array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
-        per_replica_results)
-
-  @combinations.generate(strategy_with_var_policy())
-  def testInitScope(self, distribution, use_var_policy):
-    if not context.executing_eagerly(): self.skipTest("eager only")
-
-    class C(object):
-      pass
-
-    obj = C()
-    obj.w = None
-    obj.v = None
-
-    @def_function.function
-    def assign():
-      with ops.init_scope():
-        if obj.w is None:
-          obj.w = variables_lib.Variable(
-              0, aggregation=variables_lib.VariableAggregation.MEAN)
-          obj.v = variables_lib.Variable(
-              obj.w.read_value(),
-              aggregation=variables_lib.VariableAggregation.MEAN)
-          self.evaluate(variables_lib.global_variables_initializer())
-
-      return obj.v.assign_add(2)
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(distribution.run(assign)))
-    self.assertAllEqual([2, 2], per_replica_results)
-
-  @combinations.generate(strategy_with_var_policy())
-  def testOperatorOverride(self, distribution, use_var_policy):
-
-    with distribution.scope():
-      v = variable_scope.variable(
-          1, aggregation=variables_lib.VariableAggregation.MEAN)
-      self.evaluate(variables_lib.global_variables_initializer())
-
-    self.assertEqual(2, self.evaluate(v + 1))
-
-    @def_function.function
-    def add():
-      return v + 1
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(distribution.run(add)))
-    self.assertAllEqual([2, 2], per_replica_results)
-
-
-@combinations.generate(
-    combinations.combine(
-        distribution=[
-            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-        ],
-        mode=["graph", "eager"],
-        use_var_policy=[True, False]))
-class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
-
-  def testScatterSub(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
-    self.evaluate(v.initializer)
-
-    @def_function.function
-    def scatter_sub():
-      ctx = distribution_strategy_context.get_replica_context()
-      replica_id = ctx.replica_id_in_sync_group
-      value = indexed_slices.IndexedSlices(
-          values=array_ops.stack([
-              math_ops.cast(replica_id, dtypes.float32),
-              math_ops.cast(replica_id + 1, dtypes.float32)
-          ]),
-          indices=array_ops.stack([replica_id, replica_id + 1]),
-          dense_shape=(3,))
-      return v.scatter_sub(value)
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(scatter_sub)))
-    self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)
-
-  def testScatterAdd(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
-    self.evaluate(v.initializer)
-
-    @def_function.function
-    def scatter_add():
-      ctx = distribution_strategy_context.get_replica_context()
-      replica_id = ctx.replica_id_in_sync_group
-      value = indexed_slices.IndexedSlices(
-          values=array_ops.stack([replica_id, replica_id + 1]),
-          indices=array_ops.stack([replica_id, replica_id + 1]),
-          dense_shape=(3,))
-      return v.scatter_add(value)
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(scatter_add)))
-    self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)
-
-  def testScatterDiv(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
-    self.evaluate(v.initializer)
-
-    @def_function.function
-    def scatter_div():
-      ctx = distribution_strategy_context.get_replica_context()
-      replica_id = ctx.replica_id_in_sync_group
-      value = indexed_slices.IndexedSlices(
-          values=array_ops.reshape(replica_id + 2, [1]),
-          indices=array_ops.reshape(replica_id, [1]),
-          dense_shape=(3,))
-      return v.scatter_div(value)
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(scatter_div)))
-    self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)
-
-  def testScatterMul(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
-    self.evaluate(v.initializer)
-
-    @def_function.function
-    def scatter_mul():
-      ctx = distribution_strategy_context.get_replica_context()
-      replica_id = ctx.replica_id_in_sync_group
-      value = indexed_slices.IndexedSlices(
-          values=array_ops.reshape(
-              math_ops.cast(replica_id + 2, dtypes.float32), [1]),
-          indices=array_ops.reshape(replica_id, [1]),
-          dense_shape=(3,))
-      return v.scatter_mul(value)
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(scatter_mul)))
-    self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)
-
-  def testScatterMin(self, distribution, use_var_policy):
-    with distribution.scope():
-      v1 = variables_lib.Variable(
-          [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
-      v2 = variables_lib.Variable(
-          [0, 2, 0],
-          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
-    self.evaluate(variables_lib.global_variables_initializer())
-
-    @def_function.function
-    def scatter_min(v):
-      value = indexed_slices.IndexedSlices(
-          values=array_ops.identity([1]),
-          indices=array_ops.identity([1]),
-          dense_shape=(3,))
-      return v.scatter_min(value)
-
-    with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
-      self.evaluate(
-          distribution.experimental_local_results(
-              distribution.run(scatter_min, args=(v1,))))
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(scatter_min, args=(v2,))))
-    self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)
-
-  def testScatterMax(self, distribution, use_var_policy):
-    with distribution.scope():
-      v1 = variables_lib.Variable(
-          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
-      v2 = variables_lib.Variable(
-          [0, 0, 0],
-          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
-    self.evaluate(variables_lib.global_variables_initializer())
-
-    @def_function.function
-    def scatter_max(v):
-      value = indexed_slices.IndexedSlices(
-          values=array_ops.identity([1]),
-          indices=array_ops.identity([0]),
-          dense_shape=(3,))
-      return v.scatter_max(value)
-
-    with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
-      self.evaluate(
-          distribution.experimental_local_results(
-              distribution.run(scatter_max, args=(v1,))))
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(scatter_max, args=(v2,))))
-    self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)
-
-  def testScatterUpdate(self, distribution, use_var_policy):
-    with distribution.scope():
-      v1 = variables_lib.Variable(
-          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
-      v2 = variables_lib.Variable(
-          [0, 0, 0],
-          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
-    self.evaluate(variables_lib.global_variables_initializer())
-
-    @def_function.function
-    def scatter_update(v):
-      value = indexed_slices.IndexedSlices(
-          values=array_ops.identity([3]),
-          indices=array_ops.identity([1]),
-          dense_shape=(3,))
-      return v.scatter_update(value)
-
-    with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
-      self.evaluate(
-          distribution.experimental_local_results(
-              distribution.run(scatter_update, args=(v1,))))
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.run(scatter_update, args=(v2,))))
-    self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)
-
-  def testScatterOpsInCrossReplicaContext(self, distribution, use_var_policy):
-    with distribution.scope():
-      v1 = variables_lib.Variable(
-          [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
-      v2 = variables_lib.Variable([1, 1, 1])
-    self.evaluate(variables_lib.global_variables_initializer())
-
-    value = indexed_slices.IndexedSlices(
-        values=array_ops.identity([2]),
-        indices=array_ops.identity([0]),
-        dense_shape=(3,))
-    with distribution.scope():
-      self.evaluate(v1.scatter_add(value))
-      self.assertAllEqual([3, 1, 1], self.evaluate(v1.read_value()))
-
-      self.evaluate(v2.scatter_min(value))
-      self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
-
-
-def _make_replica_local(method, strategy=None):
-  if strategy is None:
-    devices = ("/device:GPU:0", "/device:CPU:0")
-  else:
-    devices = strategy.extended.worker_devices
-
-  v = []
-  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
-    with ops.device(d):
-      v.append(variable_scope.get_variable(
-          name=n, initializer=init, use_resource=True))
-
-  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
-    var_cls = tpu_values.TPUSyncOnReadVariable
-  else:
-    var_cls = values.SyncOnReadVariable
-  replica_local = var_cls(strategy, v, method)
-  return v, replica_local
-
-
-class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssign(self, distribution, experimental_run_tf_function,
-                 use_var_policy):
-
-    def assign(fn, v, update_value, cross_replica):
-      update_fn = lambda: getattr(v, fn)(update_value)
-      if cross_replica:
-        return update_fn()
-      else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
-        return distribution.experimental_local_results(
-            distribution.run(update_fn))
-
-    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
-    aggregations = [
-        variables_lib.VariableAggregation.NONE,
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    options = list(
-        x for x in itertools.product(updates, aggregations, [True, False]))
-    for update, aggregation, cross_replica in options:
-      # VariableAggregation.SUM in cross-replica mode is tested below,
-      # VariableAggregation.NONE in cross-replica mode is not supported.
-      if cross_replica and aggregation in [
-          variables_lib.VariableAggregation.SUM,
-          variables_lib.VariableAggregation.NONE,
-      ]:
-        continue
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            synchronization=variables_lib.VariableSynchronization.ON_READ,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      fn, update_value = update
-      self.evaluate(assign(fn, v, update_value, cross_replica))
-      for component in v._values:
-        self.assertAllEqual(self.evaluate(component.read_value()),
-                            self.evaluate(array_ops.ones_like(component)))
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignOnReadVar(self, distribution, experimental_run_tf_function,
-                          use_var_policy):
-
-    with distribution.scope():
-      v_to_assign = variable_scope.variable(
-          2., aggregation=variables_lib.VariableAggregation.MEAN)
-      v_to_assign_sub = variable_scope.variable(
-          -2., aggregation=variables_lib.VariableAggregation.MEAN)
-
-    def assign(fn, v, update_value, cross_replica):
-      update_fn = lambda: getattr(v, fn)(update_value)
-      if cross_replica:
-        return update_fn()
-      else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
-        return distribution.experimental_local_results(
-            distribution.run(update_fn))
-
-    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
-               ("assign_sub", v_to_assign_sub)]
-    expected_cross_replica = {
-        variables_lib.VariableAggregation.SUM: 1.0,
-        variables_lib.VariableAggregation.MEAN: 2.0,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
-    }
-    expected_replica = {
-        variables_lib.VariableAggregation.SUM: 2.0,
-        variables_lib.VariableAggregation.MEAN: 2.0,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
-    }
-    # aggregation=NONE is not supported for OnReadVariables.
-    aggregations = [
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    options = list(
-        x for x in itertools.product(updates, aggregations, [True, False]))
-    for update, aggregation, cross_replica in options:
-      # assign in replica context with SUM does not make sense cause you can
-      # just do value * num replicas error is 1. is not a distributed value and
-      # is unsupported for aggregation SUM
-      if aggregation == variables_lib.VariableAggregation.SUM:
-        continue
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      fn, update_value = update
-      self.evaluate(assign(fn, v, update_value, cross_replica))
-      if cross_replica:
-        for component in v._values:
-          self.assertAllEqual(expected_cross_replica.get(aggregation),
-                              self.evaluate(component.read_value()))
-      else:
-        for component in v._values:
-          self.assertAllEqual(expected_replica.get(aggregation),
-                              self.evaluate(component.read_value()))
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
-                              use_var_policy):
-
-    if isinstance(distribution, _TPU_STRATEGIES):
-      self.skipTest("Assigning PerReplica values is not supported. See"
-                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")
-
-    self.skipTest("We don't support assiging PerReplica values in cross "
-                  "replica context or replica context. see error in "
-                  "sponge/2b2e54c1-eda6-4534-82e1-c73b1dcd517f.")
-
-    with distribution.scope():
-      per_replica_value = values.PerReplica(
-          [constant_op.constant(2.0),
-           constant_op.constant(2.0)])
-
-    def assign(fn, v, update_value, cross_replica):
-      update_fn = lambda: getattr(v, fn)(update_value)
-      if cross_replica:
-        return update_fn()
-      else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
-        return distribution.experimental_local_results(
-            distribution.run(update_fn))
-
-    updates = [("assign", per_replica_value)]
-    # We don't support assigning PerReplica valus to vars in replica context
-    # with aggregation=NONE.
-    aggregations = [
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    options = list(
-        x for x in itertools.product(updates, aggregations, [True, False]))
-    for update, aggregation, cross_replica in options:
-      # assign in replica context with SUM does not make sense cause you can
-      # just do value * num replicas error is 1. is not a distributed value and
-      # is unsupported for aggregation SUM
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            synchronization=variables_lib.VariableSynchronization.ON_READ,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      fn, update_value = update
-      # with self.assertRaisesRegex(ValueError, "Attempt to convert a value "):
-      self.evaluate(assign(fn, v, update_value, cross_replica))
-      if aggregation == variables_lib.VariableAggregation.SUM:
-        expected = 4.0
-      else:
-        expected = 2.0
-      for component in v._values:
-        self.assertAllEqual(expected, self.evaluate(component.read_value()))
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignDtypeConversion(self, distribution,
-                                experimental_run_tf_function,
-                                use_var_policy):
-
-    def assign(fn, v, update_value, cross_replica):
-      update_fn = lambda: getattr(v, fn)(update_value)
-      if cross_replica:
-        return update_fn()
-      else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
-        return distribution.experimental_local_results(
-            distribution.run(update_fn))
-
-    updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
-    aggregations = [
-        variables_lib.VariableAggregation.NONE,
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    options = list(
-        x for x in itertools.product(updates, aggregations, [True, False]))
-    for update, aggregation, cross_replica in options:
-      # VariableAggregation.SUM in cross-replica mode is tested below,
-      # VariableAggregation.NONE in cross-replica mode is not supported.
-      if cross_replica and aggregation in [
-          variables_lib.VariableAggregation.SUM,
-          variables_lib.VariableAggregation.NONE,
-      ]:
-        continue
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            synchronization=variables_lib.VariableSynchronization.ON_READ,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      fn, update_value = update
-      self.evaluate(assign(fn, v, update_value, cross_replica))
-      for component in v._values:
-        self.assertAllEqual(self.evaluate(component.read_value()),
-                            self.evaluate(array_ops.ones_like(component)))
-
-  @combinations.generate(strategy_with_var_policy())
-  def testAssignWithAggregationSum(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variable_scope.variable(
-          0.,
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=variables_lib.VariableAggregation.SUM)
-    self.evaluate(variables_lib.global_variables_initializer())
-    self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
-    for component in v._values:
-      self.assertAllEqual(self.evaluate(component.read_value()),
-                          self.evaluate(array_ops.ones_like(component)))
-
-  @combinations.generate(strategy_with_var_policy())
-  def testAssignAddSubWithAggregationSum(self, distribution, use_var_policy):
-    with distribution.scope():
-      v = variable_scope.variable(
-          0.,
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=variables_lib.VariableAggregation.SUM)
-    self.evaluate(variables_lib.global_variables_initializer())
-    with self.assertRaisesRegex(
-        ValueError, "SyncOnReadVariable does not support "):
-      self.evaluate(v.assign_add(1.))
-    with self.assertRaisesRegex(
-        ValueError, "SyncOnReadVariable does not support "):
-      self.evaluate(v.assign_sub(1.))
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInReplicaContext(self, distribution,
-                                    experimental_run_tf_function,
-                                    use_var_policy):
-    aggregations = [
-        variables_lib.VariableAggregation.NONE,
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    for aggregation in aggregations:
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            synchronization=variables_lib.VariableSynchronization.ON_READ,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-      if experimental_run_tf_function:
-        read_var_fn = def_function.function(v.read_value)
-      else:
-        read_var_fn = v.read_value
-      results = self.evaluate(
-          distribution.experimental_local_results(
-              distribution.run(read_var_fn)))
-      for component, value in zip(v._values, results):
-        self.assertAllEqual(self.evaluate(component.read_value()), value)
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInCrossReplicaContext(self, distribution,
-                                         experimental_run_tf_function,
-                                         use_var_policy):
-    aggregations = [
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    for aggregation in aggregations:
-      if isinstance(distribution, _TPU_STRATEGIES):
-        resolver = tpu_cluster_resolver.TPUClusterResolver("")
-        tpu_strategy_util.initialize_tpu_system(resolver)
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            synchronization=variables_lib.VariableSynchronization.ON_READ,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-
-      def assign(v=v):
-        ctx = distribution_strategy_context.get_replica_context()
-        replica_id = ctx.replica_id_in_sync_group
-        return v.assign(math_ops.cast(replica_id, dtypes.float32))
-
-      if experimental_run_tf_function:
-        assign = def_function.function(assign)
-
-      self.evaluate(
-          distribution.experimental_local_results(distribution.run(assign)))
-      num_replicas = distribution.num_replicas_in_sync
-      sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
-      if aggregation == variables_lib.VariableAggregation.SUM:
-        expected = sum_of_replica_values
-      elif aggregation == variables_lib.VariableAggregation.MEAN:
-        expected = sum_of_replica_values / num_replicas
-      else:
-        expected = 0
-      self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
-      self.assertEqual(expected, self.evaluate(v.value()), aggregation)
-      self.assertEqual(expected, self.evaluate(v), aggregation)
-      self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
-                       aggregation)
-
-  # TODO(b/145574622): Re-enable this test once ReduceOp argument is
-  # respected on GPUs.
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def disable_testAllReduce(self, distribution,
-                            experimental_run_tf_function,
-                            use_var_policy):
-    with distribution.scope():
-      v = variable_scope.variable(
-          2.,
-          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
-          aggregation=variables_lib.VariableAggregation.MEAN)
-    self.evaluate(variables_lib.global_variables_initializer())
-
-    def all_reduce():
-      ctx = distribution_strategy_context.get_replica_context()
-      replica_id = ctx.replica_id_in_sync_group
-      return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
-                                                      dtypes.float32)
-
-    if experimental_run_tf_function:
-      all_reduce = def_function.function(all_reduce)
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(distribution.run(all_reduce)))
-    expected_result = []
-    for i in range(distribution.num_replicas_in_sync):
-      expected_result.append(2.0 * distribution.num_replicas_in_sync +
-                             1.0 * i)
-    self.assertEqual(per_replica_results, tuple(expected_result))
-
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignPerReplicaBeforeRead(self, distribution,
-                                     experimental_run_tf_function,
-                                     use_var_policy):
-    aggregations = [
-        variables_lib.VariableAggregation.SUM,
-        variables_lib.VariableAggregation.MEAN,
-        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-    ]
-    for aggregation in aggregations:
-      with distribution.scope():
-        v = variable_scope.variable(
-            0.,
-            synchronization=variables_lib.VariableSynchronization.ON_READ,
-            aggregation=aggregation)
-      self.evaluate(variables_lib.global_variables_initializer())
-
-      def assign(var=v):
-        ctx = distribution_strategy_context.get_replica_context()
-        replica_id = ctx.replica_id_in_sync_group
-        return var.assign(math_ops.cast(replica_id, dtypes.float32))
-
-      if experimental_run_tf_function:
-        assign = def_function.function(assign)
-
-      per_replica_results = self.evaluate(
-          distribution.experimental_local_results(distribution.run(assign)))
-      expected_result = []
-      for i in range(distribution.num_replicas_in_sync):
-        expected_result.append(1.0 * i)
-      self.assertEqual(per_replica_results, tuple(expected_result))
-
-  @combinations.generate(strategy_with_var_policy())
-  def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution,
-                                                            use_var_policy):
-    with distribution.scope():
-      v = variable_scope.variable(
-          0.,
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=variables_lib.VariableAggregation.NONE)
-    self.evaluate(variables_lib.global_variables_initializer())
-    with self.assertRaisesRegex(
-        ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
-      self.evaluate(v.read_value())
-
-  @combinations.generate(strategy_with_var_policy())
-  def testInitializedToSameValueInsideEagerRun(self, distribution,
-                                               use_var_policy):
-    if not context.executing_eagerly(): self.skipTest("eager only")
-
-    v = [None]
-    @def_function.function
-    def step():
-      def f():
-        if v[0] is None:
-          v[0] = variables_lib.Variable(
-              random_ops.random_normal([]),
-              synchronization=variables_lib.VariableSynchronization.ON_READ)
-
-      distribution.run(f)
-
-    context.set_global_seed(None)
-    step()
-    vals = self.evaluate(v[0].values)
-    self.assertAllEqual(vals[0], vals[1])
-
-  @combinations.generate(strategy_with_var_policy())
-  def testOperatorOverride(self, distribution, use_var_policy):
-
-    with distribution.scope():
-      v = variable_scope.variable(
-          0.0,
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=variables_lib.VariableAggregation.MEAN)
-      self.evaluate(variables_lib.global_variables_initializer())
-
-      @def_function.function
-      def assign():
-        ctx = distribution_strategy_context.get_replica_context()
-        replica_id = ctx.replica_id_in_sync_group
-        return v.assign(math_ops.cast(replica_id, dtypes.float32))
-
-      # Assign different replicas with different values.
-      self.evaluate(distribution.experimental_local_results(
-          distribution.run(assign)))
-      self.assertEqual(1.5, self.evaluate(v + 1))
-
-      @def_function.function
-      def add():
-        return v + 1
-
-      per_replica_results = self.evaluate(
-          distribution.experimental_local_results(distribution.run(add)))
-      self.assertAllEqual([1, 2], per_replica_results)
-
-
-@combinations.generate(
-    combinations.combine(
-        distribution=[
-            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-        ],
-        aggregation=[
-            variables_lib.VariableAggregation.MEAN,
-            variables_lib.VariableAggregation.SUM,
-            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
-        ],
-        mode=["graph", "eager"],
-        use_var_policy=[True, False]))
-class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
-
-  def testScatterSub(self, distribution, aggregation, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [1., 1., 1.],
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-    self.evaluate(v.initializer)
-
-    delta = values.PerReplica([
-        indexed_slices.IndexedSlices(
-            values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
-        indexed_slices.IndexedSlices(
-            values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
-    ])
-
-    with self.assertRaises(NotImplementedError):
-      self.evaluate(distribution.run(v.scatter_sub, args=(delta,)))
-
-  def testScatterAdd(self, distribution, aggregation, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [1., 1., 1.],
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-    self.evaluate(v.initializer)
-
-    delta = values.PerReplica([
-        indexed_slices.IndexedSlices(
-            values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
-        indexed_slices.IndexedSlices(
-            values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
-    ])
-
-    with self.assertRaises(NotImplementedError):
-      self.evaluate(distribution.run(v.scatter_add, args=(delta,)))
-
-  def testScatterDiv(self, distribution, aggregation, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [2., 6., 1.],
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-    self.evaluate(v.initializer)
-
-    delta = values.PerReplica([
-        indexed_slices.IndexedSlices(
-            values=[[2.], [2.]], indices=[0, 1], dense_shape=(3,)),
-        indexed_slices.IndexedSlices(
-            values=[[3.], [3.]], indices=[1, 2], dense_shape=(3,)),
-    ])
-
-    with self.assertRaises(NotImplementedError):
-      self.evaluate(distribution.run(v.scatter_div, args=(delta,)))
-
-  def testScatterMul(self, distribution, aggregation, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [2., 1., 1.],
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-    self.evaluate(v.initializer)
-
-    delta = values.PerReplica([
-        indexed_slices.IndexedSlices(
-            values=[[2.], [3.]], indices=[0, 1], dense_shape=(3,)),
-        indexed_slices.IndexedSlices(
-            values=[[4.], [5.]], indices=[1, 2], dense_shape=(3,)),
-    ])
-
-    with self.assertRaises(NotImplementedError):
-      self.evaluate(distribution.run(v.scatter_mul, args=(delta,)))
-
-  def testScatterMin(self, distribution, aggregation, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [3., 4., 5.],
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-    self.evaluate(v.initializer)
-
-    delta = values.PerReplica([
-        indexed_slices.IndexedSlices(
-            values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
-        indexed_slices.IndexedSlices(
-            values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
-    ])
-
-    with self.assertRaises(NotImplementedError):
-      self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
-
-  def testScatterMax(self, distribution, aggregation, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [3., 4., 5.],
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-    self.evaluate(v.initializer)
-
-    delta = values.PerReplica([
-        indexed_slices.IndexedSlices(
-            values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
-        indexed_slices.IndexedSlices(
-            values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
-    ])
-
-    with self.assertRaises(NotImplementedError):
-      self.evaluate(distribution.run(v.scatter_max, args=(delta,)))
-
-  def testScatterUpdate(self, distribution, aggregation, use_var_policy):
-    with distribution.scope():
-      v = variables_lib.Variable(
-          [0., 0., 0.],
-          synchronization=variables_lib.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-    self.evaluate(v.initializer)
-
-    delta = values.PerReplica([
-        indexed_slices.IndexedSlices(
-            values=[[1.], [2.]], indices=[0, 1], dense_shape=(3,)),
-        indexed_slices.IndexedSlices(
-            values=[[3.], [4.]], indices=[1, 2], dense_shape=(3,)),
-    ])
-
-    with self.assertRaises(NotImplementedError):
-      self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
-
-
-def _make_index_slices(vals, indices, dense_shape=None):
-  if dense_shape:
-    dense_shape = array_ops.identity(dense_shape)
-  return indexed_slices.IndexedSlices(
-      array_ops.identity(vals), array_ops.identity(indices), dense_shape)
-
-
-if __name__ == "__main__":
-  test.main()