Group variable initialization when calling lift_to_graph.

When initializing variables defined inside a @tf.function which are lifted to the outer graph, group the variables together and call lift_to_graph once.  lift_to_graph supports passing in multiple tensors and the graph to lift to is the same for all of the variable initialization.  This improves setup time.

PiperOrigin-RevId: 284263511
Change-Id: I4cfcdb0394198df8f890a98295cc2fcb77b75413
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 2f20179..2c1ced2 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -728,13 +728,20 @@
               resource_variable_ops.var_is_initialized_op(v.handle))
         var_is_initialized = array_ops.stack(var_is_initialized).numpy()
 
+      inits = []
       for (v, init), is_initialized in zip(initializers, var_is_initialized):
         with ops.init_scope():
           if is_initialized:
             continue
+        inits.append(init)
 
+      if inits:
         op_map = lift_to_graph.lift_to_graph(
-            [init], ops.get_default_graph(), op_map=op_map)
+            inits, ops.get_default_graph(), op_map=op_map)
+      for (v, init), is_initialized in zip(initializers, var_is_initialized):
+        with ops.init_scope():
+          if is_initialized:
+            continue
         v.assign(op_map[init], read_value=False)
 
     with ops.init_scope():
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index dca257f..b558412 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -137,6 +137,19 @@
 
     self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
 
+  def testFunctionMultipleVariableInitializer(self):
+
+    state = []
+
+    @def_function.function
+    def fn(x):
+      if not state:
+        state.append(variables.Variable(lambda: 2.0))
+        state.append(variables.Variable(lambda: 5.0))
+      return state[0] * x, state[1] * x
+
+    self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])
+
   def testFunctionInitializationFunction(self):
 
     state = []