Correctly tag forward and backward func graphs on while loops again
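
For context: taking gradients through a v2 while loop requires the While op to
expose its cond/body FuncGraphs, which this change re-attaches as `_cond_graph`
and `_body_graph` below. A minimal sketch of the kind of program this affects,
using only public TF 2.x APIs (`square_until` is a made-up name):

    import tensorflow as tf

    @tf.function
    def square_until(x):
      # while_v2 builds a functional While op for this loop.
      r = tf.while_loop(lambda v: v < 10.0, lambda v: (v * v,), [x])
      return r[0]

    x = tf.constant(2.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = square_until(x)  # Two iterations: y = x**4 = 16.0.
    print(tape.gradient(y, x))  # d(x**4)/dx = 4 * x**3 = 32.0.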

PiperOrigin-RevId: 338374019
Change-Id: I0533afa733b81dea1d1a09fb034c2c9241e378ae
diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py
index 046897d..e19dec8 100644
--- a/tensorflow/python/ops/control_flow_util_v2.py
+++ b/tensorflow/python/ops/control_flow_util_v2.py
@@ -189,10 +189,10 @@
     output_idx = int(output_idx)
     node_def = node_defs[op_name]
 
-    if node_def.op in ("Identity", "While"):
+    if node_def.op == "While":
       # Captured resources occur at the same index in the lists of inputs and
-      # outputs of a while or identity op. So we lookup the input of `tensor.op`
-      # at the same index as the index of `tensor` in the `tensor.op.outputs`.
+      # outputs of a while op, so we look up the input of `tensor.op` at the
+      # index of `tensor` in `tensor.op.outputs`.
       tensor_name = node_def.input[output_idx]
     elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
       # Functions output any captured resource tensors used by their
@@ -357,3 +357,4 @@
     return results
   else:
     return make_op(inputs)
+
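
For illustration, a sketch of the index-based walk the updated comment above
describes; `_resource_input_for_output` is a hypothetical helper (the real code
walks a FunctionDef's node_defs instead), but `Operation.type`,
`Operation.inputs`, and `Tensor.value_index` are real TF APIs:

    def _resource_input_for_output(tensor):
      # A While op forwards each captured resource handle from input i to
      # output i, so the producing input sits at the output's own index.
      op = tensor.op
      if op.type == "While":
        return op.inputs[tensor.value_index]
      return tensor
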
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index f5ed047..9e773f4 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -276,9 +276,12 @@
           body_graph,
           output_shapes=output_shapes,
           parallel_iterations=parallel_iterations,
-          name=scope,
-          attach_func_graphs=True,
-          num_original_outputs=num_original_outputs)
+          name=scope)
+      # Needed so we do not compute derivatives w.r.t. these extra outputs.
+      outputs[0].op._set_attr("_num_original_outputs",
+                              attr_value_pb2.AttrValue(i=num_original_outputs))
+    outputs[0].op._cond_graph = cond_graph
+    outputs[0].op._body_graph = body_graph
     if not ops.get_default_graph().building_function:
       # In V1 graph mode, return identities for each output of the While op,
       # rather than the output of the While op directly. This makes pruning work
@@ -405,11 +408,7 @@
       body_grad_graph,
       output_shapes=[t.shape for t in body_grad_graph.outputs],
       parallel_iterations=parallel_iterations,
-      name="%s_grad" % while_op.name,
-      # TODO(allenl): It seems like attach_func_graphs=True should work, but
-      # historically we haven't attached them here and it appears to break some
-      # tests.
-      attach_func_graphs=False)
+      name="%s_grad" % while_op.name)
 
   # See comment in while_loop.
   outputs = [array_ops.identity(t) for t in outputs]
@@ -417,8 +416,7 @@
 
 
 def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
-                    parallel_iterations, name, attach_func_graphs,
-                    num_original_outputs=None):
+                    parallel_iterations, name):
   """Builds the functional StatelessWhile/While op."""
   cond_stateful_ops = [
       op for op in cond_graph.get_operations() if op._is_stateful
@@ -443,18 +441,6 @@
     util.maybe_set_lowering_attr(while_op)
     util.maybe_propagate_compile_time_consts_in_xla(while_op)
     _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
-    # This is needed so we do not compute derivative wrt these extra outputs.
-    if num_original_outputs is not None:
-      while_op._set_attr("_num_original_outputs",
-                         attr_value_pb2.AttrValue(i=num_original_outputs))
-    if attach_func_graphs:
-      # The while op may be created inside a tf.function, in which case ops
-      # needs to capture "through" it when taking gradients; outer_graph is used
-      # as a sanity check that capturing only happens from parent to child.
-      cond_graph.outer_graph = ops.get_default_graph()
-      body_graph.outer_graph = ops.get_default_graph()
-      while_op._cond_graph = cond_graph
-      while_op._body_graph = body_graph
     return tensors
   return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
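
For illustration, a sketch of how gradient code can consume what this change
sets on the forward While op; `_graphs_and_real_outputs` is a hypothetical
helper, while `get_attr("_num_original_outputs")` and the private
`_cond_graph`/`_body_graph` handles are as used in the diff above:

    def _graphs_and_real_outputs(while_op):
      # "_num_original_outputs" marks where the user-visible outputs end;
      # everything past it is accumulator state that should not be
      # differentiated.
      num_original = while_op.get_attr("_num_original_outputs")
      real_outputs = while_op.outputs[:num_original]
      # The attached FuncGraphs let gradient code capture through the op
      # without re-importing the cond/body FunctionDefs.
      return while_op._cond_graph, while_op._body_graph, real_outputs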