Move `create_mirrored_variable` into `values.py`.
PiperOrigin-RevId: 273520996
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index a13c7da..0bfe4e2 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -111,7 +111,6 @@
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context as eager_context
from tensorflow.python.eager import monitoring
-from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -2273,89 +2272,6 @@
distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access
-def create_mirrored_variable( # pylint: disable=missing-docstring
- strategy, device_map, logical_device, real_mirrored_creator, mirrored_cls,
- sync_on_read_cls, *args, **kwargs):
- # Figure out what collections this variable should be added to.
- # We'll add the MirroredVariable to those collections instead.
- collections = kwargs.pop("collections", None)
- if collections is None:
- collections = [ops.GraphKeys.GLOBAL_VARIABLES]
- kwargs["collections"] = []
-
- synchronization = kwargs.get(
- "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
-
- if synchronization == variable_scope.VariableSynchronization.NONE:
- raise ValueError(
- "`NONE` variable synchronization mode is not supported with `Mirrored` "
- "distribution strategy. Please change the `synchronization` for "
- "variable: " + kwargs["name"])
- elif synchronization == variable_scope.VariableSynchronization.ON_READ:
- is_sync_on_read = True
- elif synchronization in (
- variable_scope.VariableSynchronization.ON_WRITE,
- variable_scope.VariableSynchronization.AUTO):
- # `AUTO` synchronization defaults to `ON_WRITE`.
- is_sync_on_read = False
- else:
- raise ValueError(
- "Invalid variable synchronization mode: %s for variable: %s" %
- (synchronization, kwargs["name"]))
-
- aggregation = kwargs.pop(
- "aggregation", variable_scope.VariableAggregation.NONE)
-
- if aggregation not in (
- variable_scope.VariableAggregation.NONE,
- variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN,
- variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
- raise ValueError(
- "Invalid variable aggregation mode: %s for variable: %s" %
- (aggregation, kwargs["name"]))
-
- # Ignore user-specified caching device, not needed for mirrored variables.
- kwargs.pop("caching_device", None)
-
- # TODO(josh11b,apassos): It would be better if variable initialization
- # was never recorded on the tape instead of having to do this manually
- # here.
- with tape.stop_recording():
- devices = device_map.logical_to_actual_devices(logical_device)
- value_list = real_mirrored_creator(devices, *args, **kwargs)
-
- var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
-
- result = var_cls(
- strategy, device_map, value_list, aggregation,
- logical_device=logical_device)
-
- # Add the wrapped variable to the requested collections.
- # The handling of eager mode and the global step matches
- # ResourceVariable._init_from_args().
- if not eager_context.executing_eagerly():
- g = ops.get_default_graph()
- # If "trainable" is True, next_creator() will add the member variables
- # to the TRAINABLE_VARIABLES collection, so we manually remove
- # them and replace with the MirroredVariable. We can't set
- # "trainable" to False for next_creator() since that causes functions
- # like implicit_gradients to skip those variables.
- if kwargs.get("trainable", True):
- collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
- l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
- for value in value_list:
- for i, trainable_variable in enumerate(l):
- if value is trainable_variable:
- del l[i]
- break
-
- g.add_to_collections(collections, result)
- elif ops.GraphKeys.GLOBAL_STEP in collections:
- ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
-
- return result
-
# ------------------------------------------------------------------------------
# Metrics to track which distribution strategy is being called
distribution_strategy_gauge = monitoring.StringGauge(
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index f8ddb8f..8c90af0 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -536,7 +536,7 @@
value_list.append(v)
return value_list
- return distribute_lib.create_mirrored_variable(
+ return values.create_mirrored_variable(
self._container_strategy(), device_map, logical_device,
_real_mirrored_creator, values.MirroredVariable,
values.SyncOnReadVariable, *args, **kwargs)
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 0073f3c..3d91d2e 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -429,7 +429,7 @@
value_list.append(v)
return value_list
- return distribute_lib.create_mirrored_variable(
+ return values.create_mirrored_variable(
self._container_strategy(), device_map, logical_device,
_real_mirrored_creator, values.TPUMirroredVariable,
values.TPUSyncOnReadVariable, *args, **kwargs)
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index cafd68a..1af39d9 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -977,6 +977,89 @@
for v in self._mirrored_variable.values))
+def create_mirrored_variable( # pylint: disable=missing-docstring
+ strategy, device_map, logical_device, real_mirrored_creator, mirrored_cls,
+ sync_on_read_cls, *args, **kwargs):
+ # Figure out what collections this variable should be added to.
+ # We'll add the MirroredVariable to those collections instead.
+ var_collections = kwargs.pop("collections", None)
+ if var_collections is None:
+ var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ synchronization = kwargs.get(
+ "synchronization", vs.VariableSynchronization.ON_WRITE)
+
+ if synchronization == vs.VariableSynchronization.NONE:
+ raise ValueError(
+ "`NONE` variable synchronization mode is not supported with `Mirrored` "
+ "distribution strategy. Please change the `synchronization` for "
+ "variable: " + kwargs["name"])
+ elif synchronization == vs.VariableSynchronization.ON_READ:
+ is_sync_on_read = True
+ elif synchronization in (
+ vs.VariableSynchronization.ON_WRITE,
+ vs.VariableSynchronization.AUTO):
+ # `AUTO` synchronization defaults to `ON_WRITE`.
+ is_sync_on_read = False
+ else:
+ raise ValueError(
+ "Invalid variable synchronization mode: %s for variable: %s" %
+ (synchronization, kwargs["name"]))
+
+ aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+
+ if aggregation not in (
+ vs.VariableAggregation.NONE,
+ vs.VariableAggregation.SUM,
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_REPLICA):
+ raise ValueError(
+ "Invalid variable aggregation mode: %s for variable: %s" %
+ (aggregation, kwargs["name"]))
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ devices = device_map.logical_to_actual_devices(logical_device)
+ value_list = real_mirrored_creator(devices, *args, **kwargs)
+
+ var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
+
+ result = var_cls(
+ strategy, device_map, value_list, aggregation,
+ logical_device=logical_device)
+
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for value in value_list:
+ for i, trainable_variable in enumerate(l):
+ if value is trainable_variable:
+ del l[i]
+ break
+
+ g.add_to_collections(var_collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in var_collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
+ return result
+
+
class MirroredVariable(DistributedVariable, Mirrored):
"""Holds a map from replica to variables whose values are kept in sync."""