When recreating MirroredVariable during saved_model loading, set its component variables' initializers to no_op instead of leaving them as None.

For some reason the component variables are placed in the global variables collection; as a result, anything that relies on that collection, such as `global_variables_initializer`, breaks if a component variable has no initializer.
PiperOrigin-RevId: 359637861
Change-Id: I164c8064241a2341ee4e271cef8d6ae98ce8d38f
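
As a rough illustration of the failure mode this change addresses (a hedged sketch using the public tf API; the path "/tmp/mirrored_model" is made up, and the test added below exercises the internal equivalents):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      model = tf.train.Checkpoint(v=tf.Variable(1.0))
    tf.saved_model.save(model, "/tmp/mirrored_model")

    # Loading under the strategy recreates the MirroredVariable; its
    # component variables land in the GLOBAL_VARIABLES collection.
    g = tf.Graph()
    with g.as_default(), strategy.scope():
      tf.saved_model.load("/tmp/mirrored_model")
      # Without this change, the recreated component variables have no
      # initializer op and grouping them here fails; with it, each gets
      # a no_op initializer and this succeeds.
      init = tf.compat.v1.global_variables_initializer()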
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 7c9316a..e1e6dda 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1541,16 +1541,26 @@
":collective_all_reduce_strategy",
":combinations",
":distribute_lib",
+ ":distribute_utils",
":strategy_combinations",
":values",
- "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:config",
"//tensorflow/python:constant_op",
- "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:func_graph",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:rnn_cell",
"//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
"//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
+ "//tensorflow/python/saved_model:load",
+ "//tensorflow/python/saved_model:save",
+ "//tensorflow/python/training/tracking:util",
],
)
diff --git a/tensorflow/python/distribute/distribute_utils.py b/tensorflow/python/distribute/distribute_utils.py
index 3a7840b..6df392e 100644
--- a/tensorflow/python/distribute/distribute_utils.py
+++ b/tensorflow/python/distribute/distribute_utils.py
@@ -304,6 +304,16 @@
   # here.
   with tape.stop_recording():
     value_list = real_mirrored_creator(**kwargs)
+    # MirroredVariable is recreated during saved_model loading, and its
+    # component variables (value_list) will have None initializers. We
+    # set their initializers to no_op so that consumers like
+    # `global_variables_initializer` don't complain, since it groups all
+    # variables' initializers and therefore every variable needs one.
+    for v in value_list:
+      # pylint:disable=protected-access
+      if v._initializer_op is None:
+        v._initializer_op = control_flow_ops.no_op()
+      # pylint:enable=protected-access
     if use_var_policy:
       var_policy_cls = policy_mapping.get(synchronization)
       var_policy = var_policy_cls(aggregation=aggregation)
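
For reference, the "groups all variables' initializers" behavior mentioned in the comment above: `global_variables_initializer` is effectively `variables_initializer(global_variables())`, which groups the initializer op of every variable in the GLOBAL_VARIABLES collection. A simplified sketch (not the exact TF source):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import control_flow_ops

    def _initialize_globals():
      # Grouping fails if any variable's initializer is None, which is
      # why the loaded component variables need a no_op initializer.
      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      return control_flow_ops.group(*[v.initializer for v in global_vars])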
diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py
index 53a18fb..acf6861 100644
--- a/tensorflow/python/distribute/mirrored_variable_test.py
+++ b/tensorflow/python/distribute/mirrored_variable_test.py
@@ -39,6 +39,9 @@
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.saved_model import load
+from tensorflow.python.saved_model import save
+from tensorflow.python.training.tracking import util as tracking_util
 
 
 def _replica_id():
@@ -590,6 +593,29 @@
     self.assertIs(distribution, mirrored.distribute_strategy)
     self.assertIs(distribution, sync_on_read.distribute_strategy)
 
+  def testInitializer(self, distribution, mode):
+    if mode == "graph":
+      self.skipTest("Skip graph mode")
+
+    temp_dir = self.get_temp_dir()
+
+    class Model(tracking_util.Checkpoint):
+
+      def __init__(self):
+        self._v = variables.Variable(1.0)
+
+    with distribution.scope():
+      m = Model()
+    save.save(m, temp_dir)
+
+    g = ops.Graph()
+    with g.as_default():
+      with distribution.scope():
+        load.load(temp_dir)
+
+      for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
+        self.assertIsNotNone(v.initializer)
+
 
 if __name__ == "__main__":
   test.main()