Add optional `input_tensors` argument to `Operation._control_flow_post_processing()`.
Currently, this method causes the `inputs` property of an `Operation` to be eagerly evaluated via the C API, which can be expensive for large ops (such as tf.while_loop() ops with a large number of loop variables). Since in most cases we already possess a list of tensors from the `Operation.__init__()` `inputs` argument, we can avoid evaluating `inputs`.
PiperOrigin-RevId: 268051042
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 113936a..2db726b 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1804,15 +1804,22 @@
self._graph._add_op(self, self._id_value, name) # pylint: disable=protected-access
if not c_op:
- self._control_flow_post_processing()
+ self._control_flow_post_processing(input_tensors=inputs)
- def _control_flow_post_processing(self):
+ def _control_flow_post_processing(self, input_tensors=None):
"""Add this op to its control flow context.
This may add new ops and change this op's inputs. self.inputs must be
available before calling this method.
+
+ Args:
+ input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
+ of this op, which should be equivalent to `self.inputs`. Pass this
+ argument to avoid evaluating `self.inputs` unnecessarily.
"""
- for input_tensor in self.inputs:
+ if input_tensors is None:
+ input_tensors = self.inputs
+ for input_tensor in input_tensors:
control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
if self._control_flow_context is not None:
self._control_flow_context.AddOp(self)