When recreating MirroredVariable during saved_model loading, set its component variables' initializers to no_op (other than None).

For some reason component variables are put into global variable collection, as a result stuff like `global_variables_initializer` will break if the component variable doesn't have an initializer.

PiperOrigin-RevId: 359637861
Change-Id: I164c8064241a2341ee4e271cef8d6ae98ce8d38f
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 7c9316a..e1e6dda 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1541,16 +1541,26 @@
         ":collective_all_reduce_strategy",
         ":combinations",
         ":distribute_lib",
+        ":distribute_utils",
         ":strategy_combinations",
         ":values",
-        "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:config",
         "//tensorflow/python:constant_op",
-        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:func_graph",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:rnn_cell",
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
+        "//tensorflow/python/saved_model:load",
+        "//tensorflow/python/saved_model:save",
+        "//tensorflow/python/training/tracking:util",
     ],
 )
 
diff --git a/tensorflow/python/distribute/distribute_utils.py b/tensorflow/python/distribute/distribute_utils.py
index 3a7840b..6df392e 100644
--- a/tensorflow/python/distribute/distribute_utils.py
+++ b/tensorflow/python/distribute/distribute_utils.py
@@ -304,6 +304,16 @@
   # here.
   with tape.stop_recording():
     value_list = real_mirrored_creator(**kwargs)
+    # MirroredVariable is recreated during saved_model loading, and its
+    # component variables (value_list) will have None initializer. We
+    # set their initializers to no_op so that consumer like
+    # `global_variables_initializer` wouldn't complain, as it groups all
+    # variables' initializers thus all variables have to have initializers.
+    for v in value_list:
+      # pylint:disable=protected-access
+      if v._initializer_op is None:
+        v._initializer_op = control_flow_ops.no_op()
+      # pylint:enable=protected-access
     if use_var_policy:
       var_policy_cls = policy_mapping.get(synchronization)
       var_policy = var_policy_cls(aggregation=aggregation)
diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py
index 53a18fb..acf6861 100644
--- a/tensorflow/python/distribute/mirrored_variable_test.py
+++ b/tensorflow/python/distribute/mirrored_variable_test.py
@@ -39,6 +39,9 @@
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.saved_model import load
+from tensorflow.python.saved_model import save
+from tensorflow.python.training.tracking import util as tracking_util
 
 
 def _replica_id():
@@ -590,6 +593,29 @@
       self.assertIs(distribution, mirrored.distribute_strategy)
       self.assertIs(distribution, sync_on_read.distribute_strategy)
 
+  def testInitializer(self, distribution, mode):
+    if mode == "graph":
+      self.skipTest("Skip graph mode")
+
+    temp_dir = self.get_temp_dir()
+
+    class Model(tracking_util.Checkpoint):
+
+      def __init__(self):
+        self._v = variables.Variable(1.0)
+
+    with distribution.scope():
+      m = Model()
+    save.save(m, temp_dir)
+
+    g = ops.Graph()
+    with g.as_default():
+      with distribution.scope():
+        load.load(temp_dir)
+
+      for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
+        self.assertIsNotNone(v.initializer)
+
 
 if __name__ == "__main__":
   test.main()