Fix the bug where SyncOnReadVariable is not expanded when converted to PerReplica
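
For context (not part of the patch): SyncOnReadVariable is a DistributedValues subclass but not a PerReplica, so the previous isinstance(..., PerReplica) check in _make_tensor_into_per_replica let it fall through and be re-wrapped as a single-valued PerReplica before the all-reduce, dropping its per-replica components. The stub classes below are a minimal sketch of that dispatch change, not the real TensorFlow types:

    class DistributedValues:
      def __init__(self, values):
        self.values = tuple(values)

    class PerReplica(DistributedValues):
      pass

    class SyncOnReadVariable(DistributedValues):
      # Stub: the real SyncOnReadVariable is a DistributedVariable, which is
      # a DistributedValues subclass but NOT a PerReplica.
      pass

    def _make_tensor_into_per_replica(input_tensor):
      # Before this change the check was `isinstance(input_tensor, PerReplica)`,
      # so a SyncOnReadVariable fell through and was collapsed into a
      # single-element PerReplica.
      if isinstance(input_tensor, DistributedValues):
        return input_tensor
      return PerReplica((input_tensor,))

    var = SyncOnReadVariable([0.0, 3.0])        # one component per replica
    result = _make_tensor_into_per_replica(var)
    assert result is var
    assert len(result.values) == 2              # per-replica values preserved

With the check widened to DistributedValues, reduce() and batch_reduce() no longer need their own DistributedValues guards, which is what the cross_device_ops.py hunks below remove.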
PiperOrigin-RevId: 444666215
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index ee2d7f8..5a6fe30 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -51,6 +51,7 @@
"//tensorflow/python/client:device_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
+ "//tensorflow/python/framework:indexed_slices",
"//tensorflow/python/util",
"//tensorflow/python/util:tf_export",
"//tensorflow/tools/docs:doc_controls",
@@ -1331,14 +1332,14 @@
srcs = ["strategy_test_lib.py"],
srcs_version = "PY3",
deps = [
+ ":collective_all_reduce_strategy",
":distribute_lib",
":distribute_utils",
+ ":mirrored_strategy",
":reduce_util",
":tpu_strategy",
- ":values",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 3ffe2d1..6d8122e 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -117,7 +117,7 @@
raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
"got %r but expected a object that is not a tuple or list."
% (input_tensor,))
- if isinstance(input_tensor, value_lib.PerReplica):
+ if isinstance(input_tensor, value_lib.DistributedValues):
return input_tensor
# If input is not a Tensor, convert it to a Tensor first.
@@ -303,8 +303,8 @@
"""
if options is None:
options = collective_util.Options()
- if not isinstance(per_replica_value, value_lib.DistributedValues):
- per_replica_value = _make_tensor_into_per_replica(per_replica_value)
+
+ per_replica_value = _make_tensor_into_per_replica(per_replica_value)
validate_destinations(destinations)
@@ -352,8 +352,7 @@
if options is None:
options = collective_util.Options()
- if not isinstance(per_replica_value, value_lib.DistributedValues):
- per_replica_value = _make_tensor_into_per_replica(per_replica_value)
+ per_replica_value = _make_tensor_into_per_replica(per_replica_value)
validate_destinations(destinations)
diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py
index 2cdea45..36ad1aa 100644
--- a/tensorflow/python/distribute/strategy_common_test.py
+++ b/tensorflow/python/distribute/strategy_common_test.py
@@ -403,6 +403,51 @@
nest.map_structure(ops.convert_to_tensor, got),
nest.map_structure(ops.convert_to_tensor, expect))
+ def testSyncOnReadVariableInput(self, strategy, tf_function):
+ if (not strategy_test_lib.is_mirrored_strategy(strategy) and
+ not strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
+ not strategy_test_lib.is_tpu_strategy(strategy)):
+ self.skipTest('Skip strategies not using SyncOnReadVariables.')
+ if (strategy_test_lib.is_tpu_strategy(strategy) and
+ tf_function is combinations.no_tf_function):
+ self.skipTest('Skip TPUStrategy + eager combination.')
+ if (strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
+ tf_function is combinations.tf_function):
+ self.skipTest('Skip MWMS + graph combination until b/228512201 is fixed.')
+
+ with strategy.scope():
+ var = variables.Variable(
+ 0.0,
+ synchronization=variables.VariableSynchronization.ON_READ,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
+
+ @tf_function
+ def replica_fn():
+ replica_context = ds_context.get_replica_context()
+ replica_id = replica_context.replica_id_in_sync_group
+ var.assign(math_ops.cast(replica_id, dtype=float) * 3.0)
+
+ return replica_context.all_reduce(reduce_util.ReduceOp.SUM, var)
+
+ if strategy_test_lib.is_multi_worker_mirrored_strategy(strategy):
+ client_local_replica_num = strategy.extended._num_devices_per_worker
+ else:
+ client_local_replica_num = strategy.num_replicas_in_sync
+
+ workers_num = strategy.num_replicas_in_sync
+ expected_sum = sum(range(workers_num)) * 3.0
+
+ # Expand the values on each replica if multiple devices are used; otherwise
+ # simply read the value of the Tensor.
+ result = strategy.run(replica_fn)
+ if hasattr(result, 'values'):
+ result = result.values
+ result = nest.flatten(result)
+
+ # Iterate through all replicas and verify the reduce sum result.
+ for i in range(client_local_replica_num):
+ self.assertEqual(result[i].numpy(), expected_sum)
+
@combinations.generate(
combinations.combine(
diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py
index 47c9489..25f3f6f 100644
--- a/tensorflow/python/distribute/strategy_test_lib.py
+++ b/tensorflow/python/distribute/strategy_test_lib.py
@@ -24,9 +24,11 @@
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import backprop
@@ -136,6 +138,18 @@
return "var_list" in arg_spec.args[:-len(arg_spec.defaults)]
+def is_mirrored_strategy(strategy: distribute_lib.Strategy) -> bool:
+ return isinstance(
+ strategy,
+ (mirrored_lib.MirroredStrategy, mirrored_lib.MirroredStrategyV1))
+
+
+def is_multi_worker_mirrored_strategy(
+ strategy: distribute_lib.Strategy) -> bool:
+ return isinstance(strategy, (mwms_lib.CollectiveAllReduceStrategy,
+ mwms_lib.CollectiveAllReduceStrategyV1))
+
+
def is_tpu_strategy(strategy: distribute_lib.Strategy) -> bool:
return isinstance(strategy,
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,