Add read variable op to function control outputs

After inlining, we convert control dependencies to the function call to depend
on all control outputs. Although read variable ops doesn't have side effect
itself, it must be executed before variable updates after the function call.

Note that it's not enough to convert control dependencies to the function call
to depend on both control outputs and data outputs of the function call. The
read variable ops can be inputs to ops that have side effects, e.g. assert and
print, which are not the function data outputs.

PiperOrigin-RevId: 338391537
Change-Id: Ibc843c1a584088e54010685ebce678b047f7d94d
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index bd7a506..7ebcf77 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -4383,6 +4383,47 @@
 
     self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
 
+  @test_util.run_v2_only
+  def testControlDependencyAfterInline(self):
+    v = variables.Variable(0.)
+
+    @def_function.function
+    def assign():
+      return v.assign(1.)
+
+    @def_function.function
+    def assign_add():
+      return v.assign_add(1.)
+
+    @def_function.function
+    def f():
+      check_ops.assert_equal_v2(assign(), 1.)
+      check_ops.assert_equal_v2(assign_add(), 2.)
+
+    # We don't have a way to inspect the inlined graph in Python, so we run it
+    # multiple times to have more confidence the dependency is correct.
+    for _ in range(30):
+      f()
+
+  @test_util.run_v2_only
+  def testReadInFuncWriteOutside(self):
+    # Run many times since we are testing for a potential race condition.
+    for _ in range(30):
+      # pylint: disable=cell-var-from-loop
+      v = variables.Variable(1.)
+
+      @def_function.function
+      def add_one():
+        return v + 1.
+
+      @def_function.function
+      def get_v_plus_one():
+        v_plus_one = add_one()
+        v.assign_add(2.0)
+        return v_plus_one
+
+      self.assertAllEqual(get_v_plus_one(), 2.0)
+
 
 class MultiDeviceTest(test.TestCase, parameterized.TestCase):
 
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index 12f35a4..3f6ab98 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -373,11 +373,14 @@
       if control_flow_util.IsInWhileLoop(op):
         continue
       control_inputs = set()
-      # Ensure stateful ops run
+      # Ensure stateful ops run.
+      # Read-only ops are added to control outputs if the read value is
+      # consumed. This covers the case when the read value is returned from
+      # the function since that goes through a tf.identity in mark_as_return.
       if (op_def_registry.get(op.type) is None or
-          (op_is_stateful(op) and op.type not in utils.RESOURCE_READ_OPS)):
-        # TODO(srbs): Do not add functional ops to `ops_which_must_run` if
-        # they only have variable reads and are otherwise stateless.
+          (op_is_stateful(op) and
+           (op.type not in utils.RESOURCE_READ_OPS or
+            any(output.consumers() for output in op.outputs)))):
         ops_which_must_run.add(op)
       # Make a note of all opened manager_ids.
       if op.type == "NoOp":
diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py
index dc5d898..a7b238c 100644
--- a/tensorflow/python/framework/auto_control_deps_test.py
+++ b/tensorflow/python/framework/auto_control_deps_test.py
@@ -102,23 +102,16 @@
       self.assertNotIn(read_op1, read_op2.control_inputs)
       self.assertNotIn(read_op2, read_op1.control_inputs)
 
-  def testVariableReadsNotInOpsWithMustRun(self):
+  def testVariableReadsInOpsWithMustRun(self):
     with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       self.evaluate(variables.global_variables_initializer())
       with acd.AutomaticControlDependencies() as c:
-        read_op1 = gen_resource_variable_ops.read_variable_op(
-            v.handle, v.dtype).op
-        read_op2 = gen_resource_variable_ops.read_variable_op(
-            v.handle, v.dtype).op
-        assign_op = gen_resource_variable_ops.assign_variable_op(
-            v.handle, v + 1)
-      # Reads must not be in `ops_which_must_run` since those get added to the
-      # `control_outputs`.
-      self.assertNotIn(read_op1, c.ops_which_must_run)
-      self.assertNotIn(read_op2, c.ops_which_must_run)
-      # Last write must be in `ops_which_must_run`.
-      self.assertIn(assign_op, c.ops_which_must_run)
+        read_op = gen_resource_variable_ops.read_variable_op(v.handle,
+                                                             v.dtype).op
+        # Read ops get added to control outputs only if they have consumers.
+        c.mark_as_return(read_op.outputs[0])
+      self.assertIn(read_op, c.ops_which_must_run)
 
   def testVariableMultipleReadsAndWrites(self):
     with context.graph_mode(), self.cached_session():
@@ -142,6 +135,11 @@
             v.handle, v + 1)
         assign_op4 = gen_resource_variable_ops.assign_variable_op(
             v.handle, v + 1)
+        # Read ops get added to control outputs only if they have consumers.
+        c.mark_as_return(read_op1.outputs[0])
+        c.mark_as_return(read_op2.outputs[0])
+        c.mark_as_return(read_op3.outputs[0])
+        c.mark_as_return(read_op4.outputs[0])
 
       # Verify the control edges.
       self.assertIn(read_op1, assign_op1.control_inputs)
@@ -158,11 +156,11 @@
       for src_op, tgt_op in itertools.product(read_ops, read_ops):
         self.assertNotIn(src_op, tgt_op.control_inputs)
 
-      # Reads must not be in `ops_which_must_run`.
-      self.assertNotIn(read_op1, c.ops_which_must_run)
-      self.assertNotIn(read_op2, c.ops_which_must_run)
-      self.assertNotIn(read_op3, c.ops_which_must_run)
-      self.assertNotIn(read_op4, c.ops_which_must_run)
+      # Reads must be in `ops_which_must_run`.
+      self.assertIn(read_op1, c.ops_which_must_run)
+      self.assertIn(read_op2, c.ops_which_must_run)
+      self.assertIn(read_op3, c.ops_which_must_run)
+      self.assertIn(read_op4, c.ops_which_must_run)
       # Last write must be in `ops_which_must_run`.
       self.assertIn(assign_op4, c.ops_which_must_run)
 
diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py
index 3336d3f..b5a9e7c 100644
--- a/tensorflow/python/grappler/constant_folding_test.py
+++ b/tensorflow/python/grappler/constant_folding_test.py
@@ -96,19 +96,15 @@
         f(x, y).numpy()
     self.assertLen(graphs, 1)
     assign_count = 0
-    read_count = 0
     for node in graphs[0].node:
       if node.op == 'AssignAddVariableOp':
         self.assertEqual(node.input[0], 'y')
         assign_count += 1
-      if node.op == 'ReadVariableOp':
-        read_count += 1
 
     # Make sure that the only variable update that remains after
-    # grappler optimization is that of y, and that we prune all
-    # but the 2 necessary variable reads.
+    # grappler optimization is that of y.
     self.assertEqual(assign_count, 1)
-    self.assertEqual(read_count, 2)
+    self.assertLen(graphs[0].node, 11)
 
 
 if __name__ == '__main__':