Fix the bug where a SyncOnReadVariable is not expanded when converted to a PerReplica value

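For context, a rough repro sketch of the failure mode (illustrative only, not
part of this change; assumes a MirroredStrategy over two GPUs): a
SyncOnReadVariable handed to replica_context.all_reduce was previously not
expanded into a PerReplica value before the cross-device reduction.

  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  with strategy.scope():
    # ON_READ variables become SyncOnReadVariables under a distribution
    # strategy; each replica keeps its own copy.
    var = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

  @tf.function
  def replica_fn():
    ctx = tf.distribute.get_replica_context()
    var.assign(tf.cast(ctx.replica_id_in_sync_group, tf.float32) * 3.0)
    # The value passed to all_reduce must first be expanded into a
    # PerReplica value; with this change that also happens for a
    # SyncOnReadVariable.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, var)

  print(strategy.run(replica_fn))  # PerReplica holding 3.0 on each replica
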
PiperOrigin-RevId: 444666215
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index ee2d7f8..5a6fe30 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -51,6 +51,7 @@
         "//tensorflow/python/client:device_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/framework:indexed_slices",
         "//tensorflow/python/util",
         "//tensorflow/python/util:tf_export",
         "//tensorflow/tools/docs:doc_controls",
@@ -1331,14 +1332,14 @@
     srcs = ["strategy_test_lib.py"],
     srcs_version = "PY3",
     deps = [
+        ":collective_all_reduce_strategy",
         ":distribute_lib",
         ":distribute_utils",
+        ":mirrored_strategy",
         ":reduce_util",
         ":tpu_strategy",
-        ":values",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 3ffe2d1..6d8122e 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -117,7 +117,7 @@
     raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
                      "got %r but expected a object that is not a tuple or list."
                      % (input_tensor,))
-  if isinstance(input_tensor, value_lib.PerReplica):
+  if isinstance(input_tensor, value_lib.DistributedValues):
     return input_tensor
 
   # If input is not a Tensor, convert it to a Tensor first.
@@ -303,8 +303,8 @@
     """
     if options is None:
       options = collective_util.Options()
-    if not isinstance(per_replica_value, value_lib.DistributedValues):
-      per_replica_value = _make_tensor_into_per_replica(per_replica_value)
+
+    per_replica_value = _make_tensor_into_per_replica(per_replica_value)
 
     validate_destinations(destinations)
 
@@ -352,8 +352,7 @@
     if options is None:
       options = collective_util.Options()
 
-    if not isinstance(per_replica_value, value_lib.DistributedValues):
-      per_replica_value = _make_tensor_into_per_replica(per_replica_value)
+    per_replica_value = _make_tensor_into_per_replica(per_replica_value)
 
     validate_destinations(destinations)
 
diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py
index 2cdea45..36ad1aa 100644
--- a/tensorflow/python/distribute/strategy_common_test.py
+++ b/tensorflow/python/distribute/strategy_common_test.py
@@ -403,6 +403,51 @@
         nest.map_structure(ops.convert_to_tensor, got),
         nest.map_structure(ops.convert_to_tensor, expect))
 
+  def testSyncOnReadVariableInput(self, strategy, tf_function):
+    if (not strategy_test_lib.is_mirrored_strategy(strategy) and
+        not strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
+        not strategy_test_lib.is_tpu_strategy(strategy)):
+      self.skipTest('Skip strategies not using SyncOnReadVariables.')
+    if (strategy_test_lib.is_tpu_strategy(strategy) and
+        tf_function is combinations.no_tf_function):
+      self.skipTest('Skip TPUStrategy + eager combination.')
+    if (strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
+        tf_function is combinations.tf_function):
+      self.skipTest('Skip MWMS + graph combination until b/228512201 is fixed.')
+
+    with strategy.scope():
+      var = variables.Variable(
+          0.0,
+          synchronization=variables.VariableSynchronization.ON_READ,
+          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
+
+    @tf_function
+    def replica_fn():
+      replica_context = ds_context.get_replica_context()
+      replica_id = replica_context.replica_id_in_sync_group
+      var.assign(math_ops.cast(replica_id, dtype=float) * 3.0)
+
+      return replica_context.all_reduce(reduce_util.ReduceOp.SUM, var)
+
+    if strategy_test_lib.is_multi_worker_mirrored_strategy(strategy):
+      client_local_replica_num = strategy.extended._num_devices_per_worker
+    else:
+      client_local_replica_num = strategy.num_replicas_in_sync
+
+    workers_num = strategy.num_replicas_in_sync
+    expected_sum = sum(range(workers_num)) * 3.0
+
+    # Unpack the per-replica values if multiple devices are used; otherwise
+    # simply read the value of the tensor.
+    result = strategy.run(replica_fn)
+    if hasattr(result, 'values'):
+      result = result.values
+    result = nest.flatten(result)
+
+    # Iterate through the local replicas and verify the reduced sum result.
+    for i in range(client_local_replica_num):
+      self.assertEqual(result[i].numpy(), expected_sum)
+
 
 @combinations.generate(
     combinations.combine(
diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py
index 47c9489..25f3f6f 100644
--- a/tensorflow/python/distribute/strategy_test_lib.py
+++ b/tensorflow/python/distribute/strategy_test_lib.py
@@ -24,9 +24,11 @@
 from tensorflow.core.util import event_pb2
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.eager import backprop
@@ -136,6 +138,18 @@
   return "var_list" in arg_spec.args[:-len(arg_spec.defaults)]
 
 
+def is_mirrored_strategy(strategy: distribute_lib.Strategy) -> bool:
+  return isinstance(
+      strategy,
+      (mirrored_lib.MirroredStrategy, mirrored_lib.MirroredStrategyV1))
+
+
+def is_multi_worker_mirrored_strategy(
+    strategy: distribute_lib.Strategy) -> bool:
+  return isinstance(strategy, (mwms_lib.CollectiveAllReduceStrategy,
+                               mwms_lib.CollectiveAllReduceStrategyV1))
+
+
 def is_tpu_strategy(strategy: distribute_lib.Strategy) -> bool:
   return isinstance(strategy,
                     (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,