Move `create_mirrored_variable` into `values.py`.

PiperOrigin-RevId: 273520996
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index a13c7da..0bfe4e2 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -111,7 +111,6 @@
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import context as eager_context
 from tensorflow.python.eager import monitoring
-from tensorflow.python.eager import tape
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -2273,89 +2272,6 @@
     distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access
 
 
-def create_mirrored_variable(  # pylint: disable=missing-docstring
-    strategy, device_map, logical_device, real_mirrored_creator, mirrored_cls,
-    sync_on_read_cls, *args, **kwargs):
-  # Figure out what collections this variable should be added to.
-  # We'll add the MirroredVariable to those collections instead.
-  collections = kwargs.pop("collections", None)
-  if collections is None:
-    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
-  kwargs["collections"] = []
-
-  synchronization = kwargs.get(
-      "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
-
-  if synchronization == variable_scope.VariableSynchronization.NONE:
-    raise ValueError(
-        "`NONE` variable synchronization mode is not supported with `Mirrored` "
-        "distribution strategy. Please change the `synchronization` for "
-        "variable: " + kwargs["name"])
-  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
-    is_sync_on_read = True
-  elif synchronization in (
-      variable_scope.VariableSynchronization.ON_WRITE,
-      variable_scope.VariableSynchronization.AUTO):
-    # `AUTO` synchronization defaults to `ON_WRITE`.
-    is_sync_on_read = False
-  else:
-    raise ValueError(
-        "Invalid variable synchronization mode: %s for variable: %s" %
-        (synchronization, kwargs["name"]))
-
-  aggregation = kwargs.pop(
-      "aggregation", variable_scope.VariableAggregation.NONE)
-
-  if aggregation not in (
-      variable_scope.VariableAggregation.NONE,
-      variable_scope.VariableAggregation.SUM,
-      variable_scope.VariableAggregation.MEAN,
-      variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
-    raise ValueError(
-        "Invalid variable aggregation mode: %s for variable: %s" %
-        (aggregation, kwargs["name"]))
-
-  # Ignore user-specified caching device, not needed for mirrored variables.
-  kwargs.pop("caching_device", None)
-
-  # TODO(josh11b,apassos): It would be better if variable initialization
-  # was never recorded on the tape instead of having to do this manually
-  # here.
-  with tape.stop_recording():
-    devices = device_map.logical_to_actual_devices(logical_device)
-    value_list = real_mirrored_creator(devices, *args, **kwargs)
-
-    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
-
-    result = var_cls(
-        strategy, device_map, value_list, aggregation,
-        logical_device=logical_device)
-
-  # Add the wrapped variable to the requested collections.
-  # The handling of eager mode and the global step matches
-  # ResourceVariable._init_from_args().
-  if not eager_context.executing_eagerly():
-    g = ops.get_default_graph()
-    # If "trainable" is True, next_creator() will add the member variables
-    # to the TRAINABLE_VARIABLES collection, so we manually remove
-    # them and replace with the MirroredVariable. We can't set
-    # "trainable" to False for next_creator() since that causes functions
-    # like implicit_gradients to skip those variables.
-    if kwargs.get("trainable", True):
-      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
-      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
-      for value in value_list:
-        for i, trainable_variable in enumerate(l):
-          if value is trainable_variable:
-            del l[i]
-            break
-
-    g.add_to_collections(collections, result)
-  elif ops.GraphKeys.GLOBAL_STEP in collections:
-    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
-
-  return result
-
 # ------------------------------------------------------------------------------
 # Metrics to track which distribution strategy is being called
 distribution_strategy_gauge = monitoring.StringGauge(
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index f8ddb8f..8c90af0 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -536,7 +536,7 @@
           value_list.append(v)
       return value_list
 
-    return distribute_lib.create_mirrored_variable(
+    return values.create_mirrored_variable(
         self._container_strategy(), device_map, logical_device,
         _real_mirrored_creator, values.MirroredVariable,
         values.SyncOnReadVariable, *args, **kwargs)
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 0073f3c..3d91d2e 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -429,7 +429,7 @@
           value_list.append(v)
       return value_list
 
-    return distribute_lib.create_mirrored_variable(
+    return values.create_mirrored_variable(
         self._container_strategy(), device_map, logical_device,
         _real_mirrored_creator, values.TPUMirroredVariable,
         values.TPUSyncOnReadVariable, *args, **kwargs)
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index cafd68a..1af39d9 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -977,6 +977,89 @@
         for v in self._mirrored_variable.values))
 
 
+def create_mirrored_variable(  # pylint: disable=missing-docstring
+    strategy, device_map, logical_device, real_mirrored_creator, mirrored_cls,
+    sync_on_read_cls, *args, **kwargs):
+  # Figure out what collections this variable should be added to.
+  # We'll add the MirroredVariable to those collections instead.
+  var_collections = kwargs.pop("collections", None)
+  if var_collections is None:
+    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+  kwargs["collections"] = []
+
+  synchronization = kwargs.get(
+      "synchronization", vs.VariableSynchronization.ON_WRITE)
+
+  if synchronization == vs.VariableSynchronization.NONE:
+    raise ValueError(
+        "`NONE` variable synchronization mode is not supported with `Mirrored` "
+        "distribution strategy. Please change the `synchronization` for "
+        "variable: " + kwargs["name"])
+  elif synchronization == vs.VariableSynchronization.ON_READ:
+    is_sync_on_read = True
+  elif synchronization in (
+      vs.VariableSynchronization.ON_WRITE,
+      vs.VariableSynchronization.AUTO):
+    # `AUTO` synchronization defaults to `ON_WRITE`.
+    is_sync_on_read = False
+  else:
+    raise ValueError(
+        "Invalid variable synchronization mode: %s for variable: %s" %
+        (synchronization, kwargs["name"]))
+
+  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+
+  if aggregation not in (
+      vs.VariableAggregation.NONE,
+      vs.VariableAggregation.SUM,
+      vs.VariableAggregation.MEAN,
+      vs.VariableAggregation.ONLY_FIRST_REPLICA):
+    raise ValueError(
+        "Invalid variable aggregation mode: %s for variable: %s" %
+        (aggregation, kwargs["name"]))
+
+  # Ignore user-specified caching device, not needed for mirrored variables.
+  kwargs.pop("caching_device", None)
+
+  # TODO(josh11b,apassos): It would be better if variable initialization
+  # was never recorded on the tape instead of having to do this manually
+  # here.
+  with tape.stop_recording():
+    devices = device_map.logical_to_actual_devices(logical_device)
+    value_list = real_mirrored_creator(devices, *args, **kwargs)
+
+    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
+
+    result = var_cls(
+        strategy, device_map, value_list, aggregation,
+        logical_device=logical_device)
+
+  # Add the wrapped variable to the requested collections.
+  # The handling of eager mode and the global step matches
+  # ResourceVariable._init_from_args().
+  if not context.executing_eagerly():
+    g = ops.get_default_graph()
+    # If "trainable" is True, next_creator() will add the member variables
+    # to the TRAINABLE_VARIABLES collection, so we manually remove
+    # them and replace with the MirroredVariable. We can't set
+    # "trainable" to False for next_creator() since that causes functions
+    # like implicit_gradients to skip those variables.
+    if kwargs.get("trainable", True):
+      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+      for value in value_list:
+        for i, trainable_variable in enumerate(l):
+          if value is trainable_variable:
+            del l[i]
+            break
+
+    g.add_to_collections(var_collections, result)
+  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
+    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
+  return result
+
+
 class MirroredVariable(DistributedVariable, Mirrored):
   """Holds a map from replica to variables whose values are kept in sync."""