Correctly tag forward and backward func graphs on while loops again
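
Set the `_num_original_outputs` attribute and the `_cond_graph`/`_body_graph`
references on the forward While op in `while_loop` itself, right after
`_build_while_op` returns, instead of threading `attach_func_graphs` and
`num_original_outputs` through `_build_while_op`. The gradient While op is
built with the same helper but no longer needs an `attach_func_graphs=False`
escape hatch, and the captured-resource lookup in `control_flow_util_v2` now
special-cases only "While" ops rather than "Identity" and "While".

As a hypothetical illustration (not part of this change) of the code path
this serves, assuming TF2 with control flow v2 enabled: taking a tape
gradient through a while loop inside a tf.function relies on the forward
While op carrying references to its cond/body FuncGraphs.

    import tensorflow as tf

    x = tf.Variable(2.0)

    @tf.function
    def cube_grad():
      with tf.GradientTape() as tape:
        # Builds a functional While op; its gradient is constructed from the
        # cond/body graphs attached to that op.
        _, v = tf.while_loop(lambda i, v: i < 3,
                             lambda i, v: (i + 1, v * x),
                             [tf.constant(0), tf.constant(1.0)])
      return tape.gradient(v, x)

    print(cube_grad())  # v == x**3, so the gradient is 3 * x**2 == 12.0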
PiperOrigin-RevId: 338374019
Change-Id: I0533afa733b81dea1d1a09fb034c2c9241e378ae
diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py
index 046897d..e19dec8 100644
--- a/tensorflow/python/ops/control_flow_util_v2.py
+++ b/tensorflow/python/ops/control_flow_util_v2.py
@@ -189,10 +189,10 @@
output_idx = int(output_idx)
node_def = node_defs[op_name]
- if node_def.op in ("Identity", "While"):
+ if node_def.op == "While":
# Captured resources occur at the same index in the lists of inputs and
- # outputs of a while or identity op. So we lookup the input of `tensor.op`
- # at the same index as the index of `tensor` in the `tensor.op.outputs`.
+ # outputs of a while op. So we look up the input of `tensor.op` at the
+ # same index as the index of `tensor` in `tensor.op.outputs`.
tensor_name = node_def.input[output_idx]
elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
# Functions output any captured resource tensors used by their
@@ -357,3 +357,4 @@
return results
else:
return make_op(inputs)
+
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index f5ed047..9e773f4 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -276,9 +276,12 @@
body_graph,
output_shapes=output_shapes,
parallel_iterations=parallel_iterations,
- name=scope,
- attach_func_graphs=True,
- num_original_outputs=num_original_outputs)
+ name=scope)
+ # This is needed so we do not compute derivatives w.r.t. these extra outputs.
+ outputs[0].op._set_attr("_num_original_outputs",
+ attr_value_pb2.AttrValue(i=num_original_outputs))
+ outputs[0].op._cond_graph = cond_graph
+ outputs[0].op._body_graph = body_graph
if not ops.get_default_graph().building_function:
# In V1 graph mode, return identities for each output of the While op,
# rather than the output of the While op directly. This makes pruning work
@@ -405,11 +408,7 @@
body_grad_graph,
output_shapes=[t.shape for t in body_grad_graph.outputs],
parallel_iterations=parallel_iterations,
- name="%s_grad" % while_op.name,
- # TODO(allenl): It seems like attach_func_graphs=True should work, but
- # historically we haven't attached them here and it appears to break some
- # tests.
- attach_func_graphs=False)
+ name="%s_grad" % while_op.name)
# See comment in while_loop.
outputs = [array_ops.identity(t) for t in outputs]
@@ -417,8 +416,7 @@
def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
- parallel_iterations, name, attach_func_graphs,
- num_original_outputs=None):
+ parallel_iterations, name):
"""Builds the functional StatelessWhile/While op."""
cond_stateful_ops = [
op for op in cond_graph.get_operations() if op._is_stateful
@@ -443,18 +441,6 @@
util.maybe_set_lowering_attr(while_op)
util.maybe_propagate_compile_time_consts_in_xla(while_op)
_set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
- # This is needed so we do not compute derivative wrt these extra outputs.
- if num_original_outputs is not None:
- while_op._set_attr("_num_original_outputs",
- attr_value_pb2.AttrValue(i=num_original_outputs))
- if attach_func_graphs:
- # The while op may be created inside a tf.function, in which case ops
- # needs to capture "through" it when taking gradients; outer_graph is used
- # as a sanity check that capturing only happens from parent to child.
- cond_graph.outer_graph = ops.get_default_graph()
- body_graph.outer_graph = ops.get_default_graph()
- while_op._cond_graph = cond_graph
- while_op._body_graph = body_graph
return tensors

return util.run_as_function_for_tape_gradients(_make_op, loop_vars)