blob: 2e0d0a233f391470d383d421ee1f60ca67522343 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""while_v2 and gradient.
This is a version of while_loop that emits a single While op, as well as the
gradient function for While ops produced by while_loop. This will eventually
replace the current tf.while_loop implementation once it reaches feature and
performance parity.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util_v1
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
# pylint: disable=protected-access
# Controls parallelism in the presence of side-effecting ops like variable
# operations, print, py_function, etc. Can be set to True, False (the value
# currently assigned below), or "stateless_cond" (the intended future default,
# see the TODO at the end of this comment).
# Note that loops without side-effecting operations always execute with maximum
# parallelism, ignoring this setting. When False, loops with side-effecting ops
# execute sequentially, one iteration at a time.
# When True, loops with side-effecting ops may execute parts of different
# iterations in parallel; caution: if the loop condition contains
# side-effecting ops, this mode produces unspecified results.
# Setting it to "stateless_cond" automatically sets this mode to True when
# the loop condition is free of side-effecting ops.
# TODO(b/152548567): Change this to "stateless_cond".
glob_stateful_parallelism = False
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op.

  Args:
    cond: Callable invoked with the loop variables (packed back into the
      structure of `loop_vars`). If it returns a non-scalar TF tensor the
      result is squeezed. When `maximum_iterations` is given, the predicate is
      additionally AND-ed with `loop_counter < maximum_iterations`.
    body: Callable invoked with the loop variables. Its result must have the
      same nested structure as `loop_vars`; a non-sequence result is wrapped
      in a list.
    loop_vars: Sequence of loop variables. TensorArrays are replaced by their
      flow tensors while the loop is built and restored before `cond`/`body`
      are called and before the result is returned.
    shape_invariants: Optional structure matching `loop_vars`. When None, the
      invariants are derived from `loop_vars`.
    parallel_iterations: Forwarded to the generated While op.
    maximum_iterations: Optional upper bound on the number of iterations. It
      is also used as `max_num_elements` of the accumulator TensorLists.
    name: Optional name scope; defaults to "while".
    return_same_structure: If True, the result always has the structure of
      `loop_vars`. If False and the result flattens to exactly one element,
      that element is returned on its own.
    back_prop: If False, `stop_gradient` is applied to every returned loop
      variable and no intermediates are accumulated for the gradient.

  Returns:
    The final values of the loop variables, structured as described for
    `return_same_structure`.
  """
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)
  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    # The function names are generated outside of the loop's own name scope.
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    # The two prepended loop vars (counter and maximum_iterations) are scalars.
    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      # Squeeze predicates whose rank is unknown or non-zero.
      if (tensor_util.is_tf_type(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    # Resolve the module-level setting: "stateless_cond" enables stateful
    # parallelism only when the traced cond contains no stateful op.
    if glob_stateful_parallelism == "stateless_cond":
      stateful_parallelism = (not any(
          op._is_stateful for op in cond_graph.get_operations()))
    else:
      stateful_parallelism = glob_stateful_parallelism

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # The function was created with a signature rather than tensors, so
      # internal placeholders were created without handle data.
      _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
                        nest.flatten(args, expand_composites=True))
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        acd_record_initial_resource_uses=stateful_parallelism)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      _duplicate_body_captures_in_cond(
          cond_graph, body_graph.external_captures[num_cond_captures:])

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    # Recorded before any accumulator outputs are appended below.
    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      # For the user's loop vars, report the shape invariants rather than the
      # shapes of the body outputs.
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      outputs = _build_while_op(
          flattened_loop_vars,
          cond_graph,
          body_graph,
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope,
          num_original_outputs=num_original_outputs,
          stateful_parallelism=stateful_parallelism)
    if not ops.get_default_graph().building_function:
      # In V1 graph mode, return identities for each output of the While op,
      # rather than the output of the While op directly. This makes pruning work
      # if the output of while_loop() is fetched: the lowering pass converts the
      # While outputs into IdentityN outputs, which if fetched will cause all
      # ops in the body to be run (since it takes all exit ops as input). After
      # lowering, each output identity op will end up with only the appropriate
      # exit op as input.
      outputs = tuple(array_ops.identity(t) for t in outputs)

  # Drop the counter, maximum_iterations, captures and accumulators; only the
  # user's loop variables are returned.
  output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
                             num_flattened_outputs]
  if not back_prop:
    output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
  outputs = _pack_sequence_as(orig_loop_vars, output_loop_vars)

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
@ops.RegisterGradient("StatelessWhile")
@ops.RegisterGradient("While")
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a gradient body function for the forward loop body, rewrites the
  forward While op when the gradient needs additional accumulator outputs, and
  emits a second While op that runs the gradient body.

  Args:
    op: The op handed to the gradient function. The forward While op is looked
      up through `op.outputs[0].op`.
    *grads: Incoming gradients, one per output of the forward While op.

  Returns:
    A list of gradient values, one per entry of the processed incoming
    gradients. May include Nones.

  Raises:
    AssertionError: If, after rewriting, the number of body outputs does not
      match the number of inputs of the forward While op.
  """
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond", "_cond_graph")
  body_graph = _get_graph(while_op, "body", "_body_graph")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  # `get_attr` raises ValueError when the op does not carry the attribute; in
  # that case fall back to a default. Only that error is handled so unrelated
  # failures (and KeyboardInterrupt) are not silently swallowed.
  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except ValueError:
    num_original_outputs = len(while_op.outputs)

  try:
    stateful_parallelism = while_op.get_attr("_stateful_parallelism")
  except ValueError:
    stateful_parallelism = False

  # Outputs past `num_original_outputs` are accumulators; they get no gradient.
  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_in, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.inputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # Skip gradients with respect to the captures whenever possible.
  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
    captures_start_index = (
        len(body_graph.inputs) - len(body_graph.internal_captures))
    for i in op.skip_input_indices:
      if i >= captures_start_index:
        grads[i] = None

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations,
      stateful_parallelism)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
    # `body_graph.external_captures` added during `_create_grad_func`.
    new_inputs = body_grad_graph.extra_inputs
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
      # Continuing leads to an invalid graph with disconnected inputs.
      raise AssertionError(
          "Inputs and outputs constructed for the forward op of a While "
          "gradient don't match. This doesn't make sense, please file a bug.")
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])

  # Do not ignore grads wrt extra outputs when computing higher order
  # derivatives.
  while_op._set_attr("_num_original_outputs",
                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    # The gradient loop runs exactly as many iterations as the forward loop.
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = _build_while_op(
      loop_vars,
      cond_grad_graph,
      body_grad_graph,
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name,
      num_original_outputs=len(body_grad_graph.outputs),
      stateful_parallelism=stateful_parallelism)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
                    parallel_iterations, name, num_original_outputs,
                    stateful_parallelism):
  """Builds the functional StatelessWhile/While op.

  The stateful `While` op is used as soon as either the cond or the body
  function contains a stateful op; otherwise `StatelessWhile` is emitted.

  Args:
    loop_vars: Flat list of input tensors for the op.
    cond_graph: FuncGraph of the loop condition.
    body_graph: FuncGraph of the loop body.
    output_shapes: Shapes to record on the op's `output_shapes` attr.
    parallel_iterations: Forwarded to the op.
    name: Name for the op.
    num_original_outputs: Stored on the op as `_num_original_outputs`.
    stateful_parallelism: Stored on the op as `_stateful_parallelism`.

  Returns:
    Whatever `util.run_as_function_for_tape_gradients` returns for the op
    builder and `loop_vars`.
  """
  has_stateful_op = any(
      graph_op._is_stateful
      for graph in (cond_graph, body_graph)
      for graph_op in graph.get_operations())
  if has_stateful_op:
    op_fn = gen_functional_ops._while
  else:
    op_fn = gen_functional_ops.stateless_while

  def _make_op(inputs):
    cond_fn = util.create_new_tf_function(cond_graph)
    body_fn = util.create_new_tf_function(body_graph)
    raw_result = op_fn(
        inputs,
        cond_fn,
        body_fn,
        output_shapes=output_shapes,
        parallel_iterations=parallel_iterations,
        name=name)
    while_op, tensors = util.get_op_and_outputs(raw_result)

    _copy_handle_data(body_graph.outputs, tensors)
    util.maybe_set_lowering_attr(while_op)
    util.maybe_propagate_compile_time_consts_in_xla(while_op)
    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])

    # Recorded so that no derivative is computed wrt the extra outputs.
    original_outputs_attr = attr_value_pb2.AttrValue(i=num_original_outputs)
    while_op._set_attr("_num_original_outputs", original_outputs_attr)
    parallelism_attr = attr_value_pb2.AttrValue(b=stateful_parallelism)
    while_op._set_attr("_stateful_parallelism", parallelism_attr)

    # The while op may be created inside a tf.function, in which case ops
    # need to capture "through" it when taking gradients; outer_graph is used
    # as a sanity check that capturing only happens from parent to child.
    enclosing_graph = ops.get_default_graph()
    cond_graph.outer_graph = enclosing_graph
    body_graph.outer_graph = enclosing_graph

    # Cache the function graphs on the op.
    while_op._cond_graph = cond_graph
    while_op._body_graph = body_graph
    return tensors

  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
def _get_intermediates(func_graph):
"""Returns all tensors in `func_graph` that should be accumulated."""
# We currently accumulate output tensors of most ops in the function and rely
# on the pruning pass to get rid of the unused accumulators at runtime.
# However, this can bloat the GraphDef and make debugging harder so we perform
# some optimizations.
#
# Optimization we currently perform:
# 1. We do not accumulate tensors which already have an accumulator
# in the loop body.
# 2. We do not accumulate outputs of Identity nodes. When building the
# FuncGraph, we add an Identity node for each output (see
# `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
# of all these nodes bloats the GraphDef quite a bit so we remove those.
# Since the gradient of an Identity node does not rely on its forward op's
# input this is safe to do.
#
# Other possible optimizations:
# 1. Only accumulate tensors that will be required by the backward pass.
# This will require running the gradient pass and hence would increase the
# graph building time for the forward pass.
# 2. Do not accumulate Const nodes created inside the loop body.
# 3. Do not accumulate loop vars that are returned as-is just like captured
# tensors.
intermediates = []
reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)
for op in func_graph.get_operations():
if op.type == "Identity":
continue
# Accumulating mutexes can cause deadlock.
if op.type == "MutexLock":
continue
for o in op.outputs:
if (o is not func_graph.inputs[0] and # Loop counter.
o.dtype != dtypes.resource and # Do not accumulate resource tensors.
_get_accumulator(o) is None and # Has existing accumulator.
o.ref() not in reverse_captures
): # Captured value, hence loop invariant.
intermediates.append(o)
return intermediates
def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
  """Returns the initial gradient to be used for a given output tensor.

  Args:
    grad: the original gradient Tensor passed to the gradient function.
    body_graph_output: the corresponding Tensor in the body graph.
    while_op_input: the corresponding Tensor input of the While op.
    while_op_output: the corresponding Tensor output of the While op.

  Returns:
    A Tensor or None.
  """
  # Set the incoming gradient of non-trainable inputs to None. It is possible
  # that we receive non-None gradients for non-trainable types in nested while
  # loops because we accumulate outputs of the inner while as variant tensors
  # which are trainable and hence receive zeros_like tensors in the gradient
  # pass. The non-trainable tensors then receive the popped zeros tensor from
  # this zeros variant. The gradient for the loop vars corresponding to these
  # tensors is None or zeros (this happens only if the loop var is accumulated
  # as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
  if not _is_trainable(body_graph_output):
    return None

  # GradientTape initializes resource and variant grads as None instead of
  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
  # returning None.
  # TODO(b/143286622): The supports_default_grad check is needed
  # because While op emits non-differentiable resource tensors
  # as outputs. Remove this check when that is not the case.
  # Note: We use `while_op_input` instead of `while_op_output` for the call
  # to `supports_default_grad` because `while_op_output` may be missing
  # handle_data if the While is in a restored saved model.
  if while_op_output.dtype in (dtypes.resource, dtypes.variant):
    if (default_gradient.supports_default_grad(while_op_input) and
        grad is None):
      return _zeros_like(while_op_input, while_op_output)

  # Convert IndexedSlices to dense tensors since it is unlikely that downstream
  # gradient functions will properly handle indexed slices. This is similar to
  # what we do in tf.function gradients.
  return ops.convert_to_tensor(grad) if isinstance(
      grad, ops.IndexedSlices) else grad
# TODO(skyewm): make this return constants if op_output's shape is fully
# defined (this can be done by checking the "shape" attr of resource vars).
def _zeros_like(op_input, op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles.

  Args:
    op_input: The While op input corresponding to `op_output`.
    op_output: The While op output to build zeros for.

  Returns:
    A zeros tensor. For resource outputs, its shape is taken from
    `variable_shape(op_output)` and its dtype from `op_input`.
  """
  if op_output.dtype != dtypes.resource:
    return array_ops.zeros_like(op_output)

  variable_shape = gen_resource_variable_ops.variable_shape(op_output)
  # Note: We use `op_input` instead of `op_output` to get the zeros dtype
  # because `op_output` may be missing handle_data if the While is in a
  # restored saved model.
  zeros_dtype = default_gradient.get_zeros_dtype(op_input)
  return array_ops.zeros(variable_shape, dtype=zeros_dtype)
def _is_trainable(tensor):
  """Returns whether the given tensor is trainable.

  Args:
    tensor: The tensor to check.

  Returns:
    False if `backprop_util.IsTrainable` rejects the tensor. For the first
    output of a `TensorListPopBack` op, the trainability of the list's
    `element_dtype`. True otherwise.
  """
  if not backprop_util.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at all.
  is_popped_list = (
      tensor.op.type == "TensorListPopBack" and tensor.value_index == 0)
  if not is_popped_list:
    return True

  assert tensor.dtype == dtypes.variant
  return backprop_util.IsTrainable(tensor.op.get_attr("element_dtype"))
def _get_graph(while_op, func_attr_name, attr_graph_name):
"""Returns `FuncGraph` for the given function attribute.
Args:
while_op: The While Operation.
func_attr_name: string
attr_graph_name: cached forward graph name
Returns:
`FuncGraph`
"""
func_graph = getattr(while_op, attr_graph_name, None)
if func_graph is None:
# TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
input_shapes = [
tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
]
func_name = while_op.get_attr(func_attr_name).name
func_graph = util.get_func_graph(while_op, input_shapes, func_name)
func_graph._while = while_op
return func_graph
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations, stateful_parallelism):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.
    stateful_parallelism: Bool, see tf.while_loop.

  Returns:
    2-tuple of (grad_func_graph, args).

  Raises:
    ValueError: If a captured tensor has no entry in
      `internal_capture_to_output`.
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Build frozen sets so that we do not have linear time lookups in
  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
  # may get updated during gradient computation because we add accumulators to
  # the forward op. However, those are not loop invariants so wouldn't affect
  # the output of `_is_loop_invariant`. Also we would never attempt to capture
  # those accumulators so `_is_loop_invariant` should never receive those new
  # tensors as args.
  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)

  args = [counter, maximum_iterations, total_iters]
  args.extend(grads)

  grad_graph = _WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                       maximum_iterations, while_op,
                                       body_graph_inputs, body_graph_outputs)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *fn_args: _grad_fn(ys, xs, fn_args, body_graph),
      args, {},
      func_graph=grad_graph,
      acd_record_initial_resource_uses=stateful_parallelism)

  # Update the list of outputs with tensors corresponding to the captured
  # tensors. We capture 3 types of tensors when building the grad fn:
  # 1. Accumulators for forward graph intermediates which are not loop
  #    invariants. The outputs corresponding to these are populated in
  #    `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
  # 2. Resources, which are output as is.
  # 3. Forward graph loop invariants, which are output as is.
  capture_to_output = grad_func_graph.internal_capture_to_output
  for external_capture, internal_capture in grad_func_graph.captures:
    capture_id = ops.tensor_id(internal_capture)
    if capture_id not in capture_to_output:
      raise ValueError(
          "Tensor %s which captures %s is in list of "
          "internal_captures but not in internal_capture_to_output." %
          (str(internal_capture), str(external_capture)))
    new_output = capture_to_output[capture_id]
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args
def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - maximum_iterations.
      args[2] - Total number of iterations.
      args[3:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors, preceded by the incremented loop counter,
    maximum_iterations and the total number of iterations.
  """
  counter, maximum_iterations, total_iters = args[0], args[1], args[2]
  incoming_grads = args[3:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # after the forward While op has been rewritten in _resolve_grad_captures.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_util._GradientsHelper(
      ys,
      xs,
      grad_ys=incoming_grads,
      src_graph=func_graph,
      unconnected_gradients="zero")

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert not any(g is None for g in grad_outs)

  loop_state = [counter + 1, maximum_iterations, total_iters]
  return loop_state + grad_outs
def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
"""Returns the tensors to pass as captured inputs to `body_grad_graph`.
`body_grad_graph` may have external references to:
1. Its outer graph containing the input gradients. These are left as-is.
2. Accumulators captured from the forward-pass graph. These should have been
added as `while_op` outputs after the gradient graph was built. We replace
these with the corresponding output of `while_op`, i.e. a tensor in
`body_graph.outer_graph`. In the case of nested control flow or functions,
the gradient logic handling `body_grad_graph.outer_graph` will make sure
the tensor from `body_graph.outer_graph` is also correctly captured.
Args:
body_graph: FuncGraph. The forward-pass body function.
body_grad_graph: FuncGraph. The body gradients function.
while_op: The forward-pass While Operation calling `body_graph`.
Returns:
A list of input tensors to be passed as the captured inputs to
`body_grad_graph`.
"""
new_capture_inputs = []
for t in body_grad_graph.external_captures:
# Resolve tensors captured from the forward graph to the outputs of the
# forward while_op.
if t.graph == body_graph:
# Captured accumulator or loop invariant.
for i, output in enumerate(t.graph.outputs):
if output is t:
t = while_op.outputs[i]
break
# Note: We rely on the capturing logic of the gradient While op graph to
# correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
# and while_v2 handle this while building their gradient functions.
assert t.graph == body_graph.outer_graph
new_capture_inputs.append(t)
return new_capture_inputs
def _get_structured_grad_output(outputs, grads, body_grad_graph):
"""Returns the values that should be returned from the while grad function.
Args:
outputs: the raw Tensor outputs of the grad While op.
grads: the input gradients to the gradient function.
body_grad_graph: _WhileBodyGradFuncGraph.
Returns:
A list of gradient values. May include Nones.
"""
result = []
# outputs[0] is the loop counter.
# outputs[1] is maximum_iterations.
# outputs[2] is the total number of loop iterations.
outputs_idx = 3
structured_outputs_idx = 3
for g in grads:
# Set None as the output gradient for tensors with None input gradient.
if g is None:
result.append(None)
continue
output = body_grad_graph.structured_outputs[structured_outputs_idx]
structured_outputs_idx += 1
if isinstance(output, ops.IndexedSlices):
# TODO(skyewm): is there a more robust way to determine the order of
# flattened IndexedSlices components?
result.append(ops.IndexedSlices(
values=outputs[outputs_idx],
indices=outputs[outputs_idx + 1],
dense_shape=outputs[outputs_idx + 2]))
outputs_idx += 3
else:
assert isinstance(output, ops.Tensor)
result.append(outputs[outputs_idx])
outputs_idx += 1
return result
def _get_accumulator(tensor):
  r"""Returns TensorList if any containing accumulated values of tensor.

  We try to find a pattern of the form:

     input_tl   tensor
        \        /
    (TensorListPushBack)
            |
        output_tl

  which satisfies the following conditions:

  1. input_tl must be in tensor.graph.inputs.
  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
  returned if such a pattern is found else None is returned.

  Args:
    tensor: The Tensor to be accumulated.

  Returns:
    A variant tensor in the same graph as `tensor` or None if no accumulator is
    found.
  """
  assert isinstance(tensor.graph, func_graph_module.FuncGraph)

  graph_inputs = tensor.graph.inputs
  graph_outputs = tensor.graph.outputs

  def get_func_graph_output(t):
    """Returns t or Identity(t) whichever exists in graph outputs else None."""
    # Identity comparison throughout: Tensor equality is not plain object
    # equality, so `in` must not be used here.
    if any(output is t for output in graph_outputs):
      return t
    consumers = t.consumers()
    if not consumers:
      # The TensorList is neither a graph output nor consumed by anything, so
      # there is no Identity to look through. Without this guard indexing the
      # empty consumer list would raise IndexError.
      return None
    # tf.defun adds an Identity for each output, check whether that is the case.
    identity_op = consumers[0]
    if identity_op.type != "Identity":
      return None
    identity_output = identity_op.outputs[0]
    if any(identity_output is output for output in graph_outputs):
      return identity_output
    return None

  for consumer in tensor.consumers():
    # Find the consumer that is a TensorListPushBack node whose TensorList input
    # is in the list of function inputs.
    if consumer.type != "TensorListPushBack":
      continue

    input_tl = consumer.inputs[0]
    accum_input_idx = next(
        (idx for idx, inp in enumerate(graph_inputs) if inp is input_tl), None)
    if accum_input_idx is None:
      # The pushed-to TensorList is not a function input.
      continue

    output = get_func_graph_output(consumer.outputs[0])
    if output is None:
      # The TensorList output of `consumer` is not in the list of function
      # outputs.
      continue

    # Only the first position at which `output` appears is considered; it must
    # line up with the position of the input TensorList.
    for accum_output_idx, out in enumerate(graph_outputs):
      if out is output:
        if accum_input_idx == accum_output_idx:
          return output
        break
  return None
# Hashable key identifying a reduction op that was moved into the forward
# graph. The fields mirror, in order, the arguments of `_create_op_internal`.
OptimizedReductionOpsCacheKey = collections.namedtuple(
    "OptimizedReductionOpsCacheKey",
    (
        "op_type",
        "inputs",
        "dtypes",
        "input_types",
        "name",
        "attrs",
        "op_def",
        "compute_device",
    ))
class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
  """FuncGraph for the gradient function of the body of a While op.

  Contains the logic for capturing the tensors from the body of the forward
  While op which is as follows:
  1. If the tensor is of resource type (these are not accumulated):
     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
        inputs and outputs at the same index.
     b. Lookup the corresponding resource tensor in the forward outer graph and
        try to capture that.
  2. If the tensor is not of resource type:
     a. Create an accumulator for that tensor and output it from the forward
        pass. Note this also requires adding it as an input to the forward pass.
     b. Capture the accumulator from the forward pass in this FuncGraph. This
        will later be resolved to the correct output of the forward While op.
     c. Pop a value from the captured placeholder and use it as the captured
        value for the forward pass tensor.

  Loop invariants and `Const` tensors of the forward graph are special-cased
  and not accumulated; see `_capture_helper`.

  Tensors that do not belong to the forward graph are handed to the base-class
  capture logic and additionally recorded in `internal_capture_to_output`.
  NOTE(review): this docstring previously stated that capturing a tensor
  outside the forward graph raises a ValueError unless `capture` is called
  with `allowlisted=True`. No such check is visible in this class -- confirm
  against the base class before relying on it.

  Note: The `captures` dict does not contain the forward tensor since it is not
  directly captured. It contains the accumulator corresponding to this forward
  tensor.

  Attributes:
    while_op_needs_rewrite: True if any non-resource intermediates were
      captured, meaning the forward While op needs to be rewritten to output the
      corresponding accumulators.
    extra_inputs: list of EmptyTensorList tensors to be used as initial input to
      the new accumulators in the forward graph. It may also contain external
      captures of the custom gradient function.
    internal_capture_to_output: dict from a tensor_id(captured placeholder) to
      the corresponding tensor that needs to be added to the list of outputs.
      For instance, when capturing an accumulator TensorList this contains the
      TensorList obtained after popping a tensor from the list. Other entries
      in this dict are expected, though not enforced, to be identities.
      This dict is needed because these output tensors need to be added to
      FuncGraph.outputs "after" the tensors returned from the gradient function.
  """

  def __init__(self, name, forward_cond_graph, forward_body_graph,
               maximum_iterations, forward_while_op, body_graph_inputs,
               body_graph_outputs):
    """Initializes the gradient body graph.

    Args:
      name: name forwarded to the base FuncGraph constructor.
      forward_cond_graph: FuncGraph for the cond of the forward While op.
      forward_body_graph: FuncGraph for the body of the forward While op.
      maximum_iterations: value passed as `max_num_elements` when creating
        accumulator TensorLists.
      forward_while_op: the forward While op; stored but not otherwise used in
        this class.
      body_graph_inputs: not used by this constructor.
      body_graph_outputs: not used by this constructor.
    """
    super(_WhileBodyGradFuncGraph, self).__init__(name)
    self.extra_inputs = []
    self.internal_capture_to_output = {}
    # FuncGraph for the body of the forward While op.
    self._forward_graph = forward_body_graph
    # FuncGraph for the cond of the forward While op.
    self._forward_cond_graph = forward_cond_graph
    self._maximum_iterations = maximum_iterations
    self._forward_while_op = forward_while_op
    # Dict from forward intermediate tensor to its indirectly captured tensor
    # in this graph. Indirect capturing happens in two ways:
    # 1. For non-resource tensors we capture their accumulators from the forward
    #    outer graph and pop values from that accumulator inside this graph
    #    using TensorListPopBack.
    # 2. For resource tensors we directly capture their corresponding tensor
    #    in the forward outer graph.
    # Keys are `ops.tensor_id(tensor)`. Copied `Const` values and captures of
    # tensors outside the forward graph are memoized here as well.
    self._indirect_captures = {}

  @property
  def while_op_needs_rewrite(self):
    # Returns the `extra_inputs` list itself rather than a bool: it is truthy
    # exactly when at least one extra input has been recorded.
    return self.extra_inputs

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Creates an op, moving selected reduction ops to the forward graph.

    All arguments are forwarded unchanged either to
    `_move_op_to_forward_graph` or to the base-class `_create_op_internal`.

    Returns:
      Whatever the chosen creation path returns.
    """
    # For a reduction op, if op is in the gradient body graph and its input is
    # from the forward graph, moving op to the forward graph means we would
    # store the tensor after the reduction as opposed to the tensor before
    # reduction, and therefore could significantly reduce memory consumption.
    # For now, we do this only for a few ops.
    #
    # We don't do this if any input tensor has already been accumulated. This
    # can happen if we output all intermediates in the forward pass.
    #
    # If in XLA context, do not move constant ops to forward pass as pushing to
    # and popping from a TensorList removes the constant property of an op and
    # breaks XLA compilation, which requires certain inputs to be compile-time
    # constant for certain ops.
    #
    # This optimization is currently also disabled when under a persistent tape,
    # since it leads to an unbounded number of side outputs. With caching it may
    # be possible to re-enable it.
    optimized_reduction_ops = {
        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
    }
    if (op_type in optimized_reduction_ops and
        not util.output_all_intermediates() and
        all(input.graph is self._forward_graph for input in inputs) and
        all(_get_accumulator(input) is None for input in inputs) and
        not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
        not util.graph_wrapped_for_higher_order_tape_gradients(
            self._forward_graph)):
      return self._move_op_to_forward_graph(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

    # Default path: create the op in this (gradient) graph.
    return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def _move_op_to_forward_graph(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Creates the op in the forward graph, reusing a cached op if present.

    The cache is stored on the forward graph object itself (attribute
    `_optimized_reduction_ops_cache`, created on first use) and is keyed by
    `_get_optimized_reduction_ops_cache_key` applied to the arguments as
    received, i.e. before `name` is rewritten below.

    Returns:
      The op created in, or previously cached for, the forward graph.
    """
    # We have a cache of reduction ops that have already been moved to the
    # forward graph, and we will check it first to avoid moving an op twice.
    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
      self._forward_graph._optimized_reduction_ops_cache = {}
    cache_key = self._get_optimized_reduction_ops_cache_key(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)
    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
        cache_key)
    if cached_op is not None:
      # This op has already been moved to the forward graph and we have it in
      # the cache.
      return cached_op

    with self._forward_graph.as_default():
      # `name` was built using name_scope stack of gradient graph and may not
      # be unique in the forward graph. `Graph.create_op` does not uniquify
      # names which are name scopes i.e. end in `/`. To ensure that the op
      # created gets a unique name in the forward graph we get rid of the
      # trailing slash.
      name = ops.name_from_scope_name(name)
      result = self._forward_graph._create_op_internal(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

      # Store the op we just moved to the forward graph so that it does
      # not need to be added there again.
      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
      return result

  def _get_optimized_reduction_ops_cache_key(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Builds a hashable `OptimizedReductionOpsCacheKey` from op arguments.

    Inputs are replaced by their `ref()`, list-like arguments become tuples,
    `attrs` becomes a tuple of (name, serialized value) pairs sorted by name,
    and `op_def` is serialized. None values are kept as None.
    """
    # We need all elements of CacheKey to be hashable.
    inputs = tuple(map(lambda t: t.ref(), inputs))

    if dtypes is not None:
      dtypes = tuple(dtypes)

    if input_types is not None:
      input_types = tuple(input_types)

    if attrs is not None:
      hashable_attrs = []
      # Sorted so that the key does not depend on dict iteration order.
      for attr_name, attr_value in sorted(attrs.items()):
        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
      attrs = tuple(hashable_attrs)

    if op_def is not None:
      op_def = op_def.SerializeToString()

    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
                                         name, attrs, op_def, compute_device)

  def _capture_helper(self, tensor, name):
    """Implements the capturing described in the class docstring.

    The cases are tried in this order: memoized capture, tensor outside the
    forward graph, (after looking through Identity ops) memoized capture again,
    loop invariant, Const, resource, and finally accumulation through a
    TensorList.

    Args:
      tensor: the tensor to capture.
      name: name forwarded to the base-class `_capture_helper`.

    Returns:
      The tensor in this graph that stands in for `tensor`.
    """
    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    if tensor.graph is not self._forward_graph:
      # Not a forward-body tensor: use the regular capture mechanism. Whether
      # it was already captured must be checked before the super call.
      already_captured = self.captured(tensor)
      captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
          tensor, name)
      if not already_captured:
        # Adds the captured tensor to the list of outputs so that the input
        # and output signatures match.
        self.internal_capture_to_output[ops.tensor_id(
            captured_tensor)] = captured_tensor
        self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    # `tensor` may have changed above, so consult the memo again.
    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # No need to accumulate loop invariants. Capture them directly.
    # The captured tensor gets resolved to the corresponding while output in
    # `_resolve_grad_captures`.
    if _is_loop_invariant(tensor, self._forward_graph.inputs,
                          self._forward_graph.outputs):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `internal_capture_to_output` so that this gets added to the list
      # of outputs.
      self.internal_capture_to_output[ops.tensor_id(
          captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.extra_inputs.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    # Memoize the popped value, and register the shortened list as the output
    # paired with the captured accumulator placeholder.
    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.internal_capture_to_output[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor

  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body it must either be a loop input or an output of a nested While
    op inside the loop body which had captured the external resource.

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    # Position of the forward-graph input that `tensor` corresponds to, as
    # computed by `util.resource_input_index`.
    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    # The While op is read from `self._forward_graph._while` here, not from
    # `self._forward_while_op`.
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant.
    assert input_placeholder is self._forward_graph.outputs[index], (
        "Resource tensors must be loop invariants %s." % tensor_in_outer_graph)

    # Capture the outer-graph resource and memoize it for later lookups.
    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
        tensor_in_outer_graph)
    return self._indirect_captures[ops.tensor_id(tensor)]
def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
for (t, shape, input_t) in zip(output_tensors, shape_invariants,
input_tensors):
if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
raise ValueError(
"Input tensor '%s' enters the loop with shape %s, but has "
"shape %s after one iteration. To allow the shape to vary across "
"iterations, use the `shape_invariants` argument of tf.while_loop to "
"specify a less-specific shape." % (input_t.name, shape, t.shape))
def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
"""Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
assert len(cond_graph.inputs) == num_flattened_loop_vars, (
"cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
num_flattened_loop_vars))
assert len(cond_graph.outputs) == 1, (
"cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
assert len(body_graph.inputs) == num_flattened_loop_vars, (
"body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
num_flattened_loop_vars))
assert len(body_graph.outputs) == num_flattened_loop_vars, (
"body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
num_flattened_loop_vars))
def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
flattened_loop_vars):
if inp.dtype != out.dtype:
raise TypeError("Loop var {} enters the loop with type {} "
"but has type {} after 1 iteration.".format(
loop_var.name, inp.dtype, out.dtype))
def _build_cond_placeholders_name_prefix(cond_graph):
return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")
def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
  """Mirrors the body graph's captures as placeholders in `cond_graph`.

  The cond and body graphs must have matching signatures, so every tensor
  captured by the body gets a placeholder of the same dtype in the cond graph.
  The placeholders are created in bulk through the C API and then registered
  in `cond_graph._captures` and appended to `cond_graph.inputs`.

  Args:
    cond_graph: cond branch graph
    body_graph_captures: Tensors which were captured when building the
      `body_graph`.
  """
  dtype_enums = [
      capture.dtype.as_datatype_enum for capture in body_graph_captures
  ]
  # TODO(srbs): Providing a unique prefix does not ensure that there is no
  # conflict between the placeholder names and existing nodes in the graph.
  # However passing a list of strings may not be performant.
  # Ideally we should move `Graph.unique_name` to C++ or make
  # `Graph._names_in_use` a trie so that we can find a unique prefix.
  # TODO(b/143286622): This should not be required once captures are separated
  # from regular loop vars.
  name_prefix = compat.as_str(_build_cond_placeholders_name_prefix(cond_graph))
  c_placeholders = c_api.TF_CreatePlaceholders(cond_graph._c_graph,
                                               dtype_enums, name_prefix)

  # First wrap every C placeholder op in a Python op, then build the output
  # tensor of each wrapped op.
  wrapped_ops = [
      _OperationWithOutputs(c_placeholder.oper, cond_graph)
      for c_placeholder in c_placeholders
  ]
  new_inputs = []
  for wrapped_op, c_placeholder, dtype_enum in zip(wrapped_ops, c_placeholders,
                                                   dtype_enums):
    placeholder = ops.Tensor._create_with_tf_output(wrapped_op, 0, dtype_enum,
                                                    c_placeholder)
    wrapped_op._outputs = [placeholder]
    new_inputs.append(placeholder)

  # Record the new placeholders in `cond_graph._captures`, keyed by the id()
  # of the external tensor, and expose them as inputs of the cond graph.
  for external, placeholder in zip(body_graph_captures, new_inputs):
    cond_graph._captures[id(external)] = (external, placeholder)
  cond_graph.inputs.extend(new_inputs)
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
custom_gradient.copy_handle_data(src_t, tgt_t)
def _graph_name(graph):
  """Returns the name of a FuncGraph, or "Base" for any other graph."""
  is_func_graph = isinstance(graph, func_graph_module.FuncGraph)
  return graph.name if is_func_graph else "Base"
def _pack_sequence_as(structure_with_tas, loop_vars):
  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays.

  Both arguments are flattened (expanding composites) and paired up. Wherever
  the structure holds a TensorArray, the matching loop var is treated as a
  flow and a TensorArray is rebuilt around it; every other loop var is kept
  as is. The result is packed back into the shape of `structure_with_tas`.
  """
  flat_values = nest.flatten(loop_vars, expand_composites=True)
  flat_structure = nest.flatten(structure_with_tas, expand_composites=True)

  restored = []
  for value, template in zip(flat_values, flat_structure):
    if isinstance(template, tensor_array_ops.TensorArray):
      value = tensor_array_ops.build_ta_with_new_flow(template, value)
    restored.append(value)

  return nest.pack_sequence_as(
      structure_with_tas, restored, expand_composites=True)
def _tensor_array_to_flow(loop_vars):
  """Replaces every TensorArray in `loop_vars` with its `flow` tensor.

  Other elements are returned unchanged; the nested structure is preserved.
  """
  return nest.map_structure(
      lambda item: (item.flow  # pylint: disable=g-long-lambda
                    if isinstance(item, tensor_array_ops.TensorArray)
                    else item),
      loop_vars,
      expand_composites=True)
def _build_maximum_iterations_loop_var(maximum_iterations):
  """Converts `maximum_iterations` to an int32 tensor loop variable.

  None is mapped to -1, the default value for `max_num_elements` of
  EmptyTensorList, meaning that the list size is unbounded.
  """
  limit = -1 if maximum_iterations is None else maximum_iterations
  # EmptyTensorList expects `max_num_elements` to be of type int32.
  return ops.convert_to_tensor(
      limit, dtype=dtypes.int32, name="maximum_iterations")
def _build_accumulator_name(tensor):
# Tensor name may be of the form "pow/y:0". Name scope does not allow ":".
return "{}/accumulator".format(tensor.name).replace(":", "_")
def _is_loop_invariant(tensor, inputs, outputs):
return (any(tensor is t for t in inputs) and
any(tensor is t for t in outputs))
class _OperationWithOutputs(ops.Operation):
  """Operation with pre-built `TF_Output`s.

  The C API for creating the extra placeholders for the cond graph returns
  SWIG wrapped TF_Output* pointers which we can use directly for
  `Operation.outputs`. The default constructor for `Operation` does not provide
  a way of specifying pre-built output tensors and always creates them. This is
  a performance overhead. It is not clear if adding that feature to the
  `Operation` API would be generally useful so for now we just have our own
  lightweight `Operation` implementation. Note that this does not extract a
  stacktrace as well since we don't expect this operation to be used.

  TODO(b/143286622): This should not be required once captures are separated
  from regular loop vars.
  """

  def __init__(self, c_op, g):
    """Wraps an existing C op and registers it with graph `g`.

    The base-class constructor is intentionally not called; only the
    attributes assigned below are set.

    Args:
      c_op: the C op to wrap.
      g: the graph the op is added to via `g._add_op`.
    """
    self._c_op = c_op
    self._graph = g
    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
    # `self.name` is read here, after `_c_op` and `_graph` have been assigned
    # (presumably the base-class property derives it from `_c_op` -- TODO
    # confirm).
    self._id_value = g._add_op(self, self.name)
    self._is_stateful = False
def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  An input index is kept only if every branch graph reports it as a read-only
  resource input. The resulting indices are written, sorted, to the
  `acd.READ_ONLY_RESOURCE_INPUTS_ATTR` attribute of `op`.

  This is used by AutomaticControlDependencies.

  Args:
    op: While Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  # Start from every input index and narrow the set down branch by branch.
  candidates = set(range(len(op.inputs)))
  for graph in branch_graphs:
    if not candidates:
      # Nothing left to rule out; skip inspecting the remaining branches.
      break
    candidates.intersection_update(
        acd.get_read_only_resource_input_indices_graph(graph))

  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(candidates))
# pylint: enable=protected-access