Group variable initialization when calling lift_to_graph.
When initializing variables defined inside a @tf.function which are lifted to the outer graph, group the variables together and call lift_to_graph once. lift_to_graph supports passing in multiple tensors and the graph to lift to is the same for all of the variable initialization. This improves setup time.
PiperOrigin-RevId: 284263511
Change-Id: I4cfcdb0394198df8f890a98295cc2fcb77b75413
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 2f20179..2c1ced2 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -728,13 +728,20 @@
resource_variable_ops.var_is_initialized_op(v.handle))
var_is_initialized = array_ops.stack(var_is_initialized).numpy()
+ inits = []
for (v, init), is_initialized in zip(initializers, var_is_initialized):
with ops.init_scope():
if is_initialized:
continue
+ inits.append(init)
+ if inits:
op_map = lift_to_graph.lift_to_graph(
- [init], ops.get_default_graph(), op_map=op_map)
+ inits, ops.get_default_graph(), op_map=op_map)
+ for (v, init), is_initialized in zip(initializers, var_is_initialized):
+ with ops.init_scope():
+ if is_initialized:
+ continue
v.assign(op_map[init], read_value=False)
with ops.init_scope():
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index dca257f..b558412 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -137,6 +137,19 @@
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ def testFunctionMultipleVariableInitializer(self):
+
+ state = []
+
+ @def_function.function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(lambda: 2.0))
+ state.append(variables.Variable(lambda: 5.0))
+ return state[0] * x, state[1] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])
+
def testFunctionInitializationFunction(self):
state = []