blob: 87f8bd85a560c76d55469a8321c49eb342f69a64 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.
See the [Control
Flow](https://tensorflow.org/api_guides/python/control_flow_ops) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import functools
import os
import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple
def _summarize_eager(tensor, summarize=None):
"""Returns a summarized string representation of eager `tensor`.
Args:
tensor: EagerTensor to summarize
summarize: Include these many first elements of `array`
"""
# reshape((-1,)) is the fastest way to get a flat array view
if tensor._rank(): # pylint: disable=protected-access
flat = tensor.numpy().reshape((-1,))
lst = [str(x) for x in flat[:summarize]]
if len(lst) < flat.size:
lst.append("...")
else:
# tensor.numpy() returns a scalar for zero dimensional arrays
if summarize != 0:
lst = [str(tensor.numpy())]
else:
lst = []
return ", ".join(lst)
# pylint: disable=protected-access
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
  a dependency:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility{eager} returns None.

  Raises:
    @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
    is not true
  """
  if context.executing_eagerly():
    # Eager mode checks the condition immediately and never builds an op.
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    host_dtypes = {dtypes.string, dtypes.int32}
    if all(x.dtype in host_dtypes for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If that is not the case,
      # we will pay the price of copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")

    condition = ops.convert_to_tensor(condition, name="Condition")

    def true_assert():
      return gen_logging_ops._assert(
          condition, data, summarize, name="Assert")

    # The assert op sits in the false branch of the cond, so it is only
    # reached when `condition` does not hold.
    guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
    # Eager execution already returned above, so only the graph-mode result
    # remains to be handled here.
    return guarded_assert.op
def _Identity(data, name=None):
  """Creates identity op(s) forwarding `data` unchanged.

  Args:
    data: A Tensor, IndexedSlices or SparseTensor (or a value convertible to
      a tensor or IndexedSlices).
    name: A name for this operation (optional).

  Returns:
    A value of the same kind as the converted input, whose component tensors
    are the outputs of identity ops.

  Raises:
    TypeError: If the converted `data` is not a Tensor, IndexedSlices or
      SparseTensor.
  """
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    # Ref-typed tensors need the ref variant of the op.
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_array_ops.ref_identity(data, name=name)
    return array_ops.identity(data, name=name)

  if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    raise TypeError("Type %s not supported" % type(data))

  values = _Identity(data.values, name=name)
  indices = array_ops.identity(data.indices, name="indices")
  if not isinstance(data, ops.IndexedSlices):
    # SparseTensor: dense_shape is always present.
    sparse_shape = array_ops.identity(data.dense_shape, name="dense_shape")
    return sparse_tensor.SparseTensor(indices, values, sparse_shape)

  # IndexedSlices: dense_shape is optional and is passed through as None.
  dense_shape = data.dense_shape
  if dense_shape is not None:
    dense_shape = array_ops.identity(dense_shape, name="dense_shape")
  return ops.IndexedSlices(values, indices, dense_shape)
def _NextIteration(data, name=None):
  """Creates NextIteration op(s) for `data`.

  Args:
    data: A Tensor, IndexedSlices or SparseTensor (or a value convertible to
      a tensor or IndexedSlices).
    name: A name for this operation (optional).

  Returns:
    A value of the same kind as the converted input, whose component tensors
    are the outputs of next-iteration ops.

  Raises:
    TypeError: If the converted `data` is not a Tensor, IndexedSlices or
      SparseTensor.
  """
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    # Ref-typed tensors need the ref variant of the op.
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return ref_next_iteration(data, name=name)
    return next_iteration(data, name=name)

  if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    raise TypeError("Type %s not supported" % type(data))

  values = _NextIteration(data.values, name=name)
  indices = next_iteration(data.indices, name="indices")
  if not isinstance(data, ops.IndexedSlices):
    # SparseTensor: dense_shape is always present.
    sparse_shape = next_iteration(data.dense_shape, name="dense_shape")
    return sparse_tensor.SparseTensor(indices, values, sparse_shape)

  # IndexedSlices: dense_shape is optional and is passed through as None.
  dense_shape = data.dense_shape
  if dense_shape is not None:
    dense_shape = next_iteration(dense_shape, name="dense_shape")
  return ops.IndexedSlices(values, indices, dense_shape)
def _Enter(data,
           frame_name,
           is_constant=False,
           parallel_iterations=10,
           use_ref=True,
           use_input_shape=True,
           name=None):
  """Creates or finds a child frame, and makes `data` available to it.

  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `data` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations`
  iterations are run in parallel in the child frame.

  Args:
    data: The tensor to be made available to the child frame.
    frame_name: The name of the child frame.
    is_constant: If true, the output is constant within the child frame.
    parallel_iterations: The number of iterations allowed to run in parallel.
    use_ref: If true, use ref_enter if data is of ref type.
    use_input_shape: If true, set the static shape of each result tensor to
      the static shape of the corresponding input tensor.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.

  Raises:
    TypeError: If the converted `data` is not a Tensor, IndexedSlices or
      SparseTensor.
  """
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
      result = gen_control_flow_ops.ref_enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    else:
      result = gen_control_flow_ops.enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    if use_input_shape:
      result.set_shape(data.get_shape())
    return result

  if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    raise TypeError("Type %s not supported" % type(data))

  def _enter_component(component, component_name):
    """Enters one index/shape component and optionally copies its shape."""
    entered = gen_control_flow_ops.enter(
        component,
        frame_name,
        is_constant,
        parallel_iterations,
        name=component_name)
    if use_input_shape:
      entered.set_shape(component.get_shape())
    return entered

  # Forward `use_ref` so the caller's choice is honored for the values
  # tensor as well, instead of silently reverting to the default.
  values = _Enter(
      data.values,
      frame_name,
      is_constant,
      parallel_iterations=parallel_iterations,
      use_ref=use_ref,
      use_input_shape=use_input_shape,
      name=name)
  indices = _enter_component(data.indices, "indices")
  if isinstance(data, ops.IndexedSlices):
    # dense_shape is optional for IndexedSlices; None is passed through.
    dense_shape = data.dense_shape
    if dense_shape is not None:
      dense_shape = _enter_component(dense_shape, "dense_shape")
    return ops.IndexedSlices(values, indices, dense_shape)

  dense_shape = _enter_component(data.dense_shape, "dense_shape")
  return sparse_tensor.SparseTensor(indices, values, dense_shape)
def exit(data, name=None):  # pylint: disable=redefined-builtin
  """Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame.

  Args:
    data: The tensor to be made available to the parent frame.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.

  Raises:
    TypeError: If the converted `data` is not a Tensor, IndexedSlices or
      SparseTensor.
  """
  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    # Ref-typed tensors need the ref variant of the op.
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_control_flow_ops.ref_exit(data, name)
    return gen_control_flow_ops._exit(data, name)  # pylint: disable=protected-access

  if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    raise TypeError("Type %s not supported" % type(data))

  # pylint: disable=protected-access
  values = exit(data.values, name=name)
  indices = gen_control_flow_ops._exit(data.indices, name="indices")
  # NOTE: the dense_shape exit ops below are created with the caller's `name`
  # rather than a fixed "dense_shape" name.
  if not isinstance(data, ops.IndexedSlices):
    sparse_shape = gen_control_flow_ops._exit(data.dense_shape, name)
    return sparse_tensor.SparseTensor(indices, values, sparse_shape)

  # IndexedSlices: dense_shape is optional and is passed through as None.
  dense_shape = data.dense_shape
  if dense_shape is not None:
    dense_shape = gen_control_flow_ops._exit(dense_shape, name)
  # pylint: enable=protected-access
  return ops.IndexedSlices(values, indices, dense_shape)
def switch(data, pred, dtype=None, name=None):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s, `IndexedSlices` and `SparseTensor`s.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    dtype: Optional element type for the returned tensor. If missing,
      the type is inferred from the type of `data`.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded
    to `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: If the converted `data` is not a Tensor, IndexedSlices or
      SparseTensor.
  """
  with ops.name_scope(name, "Switch", [data, pred]) as name:
    data = ops.internal_convert_to_tensor_or_indexed_slices(
        data, dtype=dtype, name="data", as_ref=True)
    pred = ops.convert_to_tensor(pred, name="pred")
    if isinstance(data, ops.Tensor):
      return gen_control_flow_ops.switch(data, pred, name=name)

    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError("Type %s not supported" % type(data))

    # Switch each component tensor separately on the same predicate.
    val, ind = data.values, data.indices
    val_f, val_t = gen_control_flow_ops.switch(val, pred, name=name)
    ind_f, ind_t = gen_control_flow_ops.switch(ind, pred, name="indices")
    if isinstance(data, ops.IndexedSlices):
      # dense_shape is optional for IndexedSlices; None is passed through.
      dense_shape = data.dense_shape
      if dense_shape is not None:
        dense_shape_f, dense_shape_t = gen_control_flow_ops.switch(
            dense_shape, pred, name="dense_shape")
      else:
        dense_shape_f, dense_shape_t = None, None
      return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
              ops.IndexedSlices(val_t, ind_t, dense_shape_t))

    dense_shape_f, dense_shape_t = gen_control_flow_ops.switch(
        data.dense_shape, pred, name="dense_shape")
    return (sparse_tensor.SparseTensor(ind_f, val_f, dense_shape_f),
            sparse_tensor.SparseTensor(ind_t, val_t, dense_shape_t))
def _SwitchRefOrTensor(data, pred, name="Switch"):
  """Routes `data` to one of two outputs chosen by `pred`.

  When `pred` is false the input goes to the first output, otherwise to the
  second. Ref-typed tensors are routed with `ref_switch`; everything else is
  delegated to `switch`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: if data is not a Tensor or IndexedSlices
  """
  data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
  # NOTE(vrv): colocating with `data` while ignoring the existing colocation
  # stack addresses the following scenario.
  #
  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
  #
  # 1. The update op is created inside a `with ops.colocate(var):` block
  #
  # 2. Some tensor `data` is captured and a switch is created in a
  #    `with ops.colocate_with(data):` block.
  #
  # with ops.colocate_with(var):
  #   with ops.colocate_with(data):
  #     op = ...
  #
  # var and data may be pinned to different devices, so ops created within
  # ops.colocate_with(data) should ignore the existing stack.
  with ops.colocate_with(data, ignore_existing=True):
    # pylint: disable=protected-access
    is_ref_tensor = isinstance(data, ops.Tensor) and data.dtype._is_ref_dtype
    # pylint: enable=protected-access
    if is_ref_tensor:
      return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)
def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles `Tensor`s, `SparseTensor`s and `IndexedSlices`. If all inputs
  are `SparseTensor`s the result is a `SparseTensor`; any other mix involving
  non-`Tensor` inputs is converted to `IndexedSlices` before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
  """
  if any(inp is None for inp in inputs):
    # Wrap `inputs` in a 1-tuple: if `inputs` is itself a tuple, passing it
    # bare to `%` would treat its elements as separate format arguments and
    # raise TypeError instead of the intended ValueError.
    raise ValueError(
        "At least one of the merge inputs is None: %s" % (inputs,))
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
        for inp in inputs
    ]
    if all(isinstance(v, ops.Tensor) for v in inputs):
      # The ref variant is only used when every input is ref-typed.
      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
        return gen_control_flow_ops.ref_merge(inputs, name)
      return gen_control_flow_ops.merge(inputs, name)

    if all(isinstance(v, sparse_tensor.SparseTensor) for v in inputs):
      # Only handle the case when all inputs are SparseTensor.
      values, _ = merge([inp.values for inp in inputs], name=name)
      indices, chosen_index = gen_control_flow_ops.merge(
          [inp.indices for inp in inputs], name="indices")
      dense_shape, _ = gen_control_flow_ops.merge(
          [inp.dense_shape for inp in inputs], name="dense_shape")
      return (sparse_tensor.SparseTensor(indices, values, dense_shape),
              chosen_index)

    # For now convert all the inputs as IndexedSlices.
    inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)  # pylint: disable=protected-access
    values, _ = merge([inp.values for inp in inputs], name=name)
    indices, chosen_index = gen_control_flow_ops.merge(
        [inp.indices for inp in inputs], name="indices")
    if any(inp.dense_shape is not None for inp in inputs):
      if any(inp.dense_shape is None for inp in inputs):
        raise ValueError("Either all merged IndexedSlices must have a "
                         "dense_shape, or none must have a dense_shape.")
      dense_shape, _ = gen_control_flow_ops.merge(
          [inp.dense_shape for inp in inputs], name="dense_shape")
    else:
      dense_shape = None
    return ops.IndexedSlices(values, indices, dense_shape), chosen_index
# pylint: enable=protected-access
def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  """Returns the `flow` of a TensorArray; any other value is returned as is."""
  is_tensor_array = isinstance(tensor_or_tensor_array,
                               tensor_array_ops.TensorArray)
  return tensor_or_tensor_array.flow if is_tensor_array else tensor_or_tensor_array
def _make_tensor_array(ta, t_or_flow):
  """Builds a TensorArray like `ta` but using `t_or_flow` as its flow.

  Args:
    ta: The TensorArray whose dtype, handle and private settings are copied.
    t_or_flow: The value to use as the new TensorArray's `flow`.

  Returns:
    A new TensorArray sharing `ta`'s handle, with `ta`'s colocation and
    element-shape attributes copied over.
  """
  # pylint: disable=protected-access
  ctor_kwargs = dict(
      dtype=ta.dtype,
      handle=ta.handle,
      flow=t_or_flow,
      infer_shape=ta._infer_shape,
      colocate_with_first_write_call=ta._colocate_with_first_write_call)
  rebuilt = tensor_array_ops.TensorArray(**ctor_kwargs)
  # These two are not constructor arguments here, so copy them explicitly.
  rebuilt._colocate_with = ta._colocate_with
  rebuilt._element_shape = ta._element_shape
  # pylint: enable=protected-access
  return rebuilt
def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
if len(tensors_or_tensorarrays) != len(tensors_or_flows):
raise ValueError(
"Lengths of original Tensor list and new list do not match: %d vs. %d" %
(len(tensors_or_tensorarrays), len(tensors_or_flows)))
return [
_make_tensor_array(ta, t_or_flow)
if isinstance(ta, tensor_array_ops.TensorArray) else t_or_flow
for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
]
def _ShapeLessThanOrEqual(shape1, shape2):
if shape2.dims is None:
return True
if shape1.ndims != shape2.ndims:
return False
for dim1, dim2 in zip(shape1.dims, shape2.dims):
if dim2.value is not None and dim1.value != dim2.value:
return False
return True
def _SetShapeInvariants(input_vars, enter_vars, shapes):
"""Set the shapes of the tensors in `enter_vars` to `shapes`.
Args:
input_vars: A list of tensors that are inputs to `enter_vars`.
enter_vars: A list of tensors whose shapes will be set.
shapes: A (possibly nested) list of shapes.
Raises:
ValueError: If any tensor in `enter_vars` has a less specific shape
than its corresponding shape in `shapes`.
"""
if shapes is None:
return
flat_shapes = nest.flatten(shapes)
if not all([isinstance(s, tensor_shape.TensorShape) for s in flat_shapes]):
raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
# Check that the shapes of the inputs are less than the shape invariants,
# and set the shapes of `enter_vars` to the shape invariants.
for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
if isinstance(var, ops.Tensor):
if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
raise ValueError(
"The shape invariant specified for %s is not compatible with "
"the initial shape of the loop variable. It enters the loop "
"with shape %s, but the specified shape invariant is %s." %
(inp.name, inp.get_shape(), shape))
var.set_shape(shape)
else:
if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
raise TypeError("Type %s not supported" % type(var))
if isinstance(var, ops.IndexedSlices):
if not _ShapeLessThanOrEqual(inp.values.get_shape(), shape):
raise ValueError(
"The shape invariant specified for %s is not compatible with "
"the initial shape of the values tensor of this IndexedSlices. "
"It enters the loop with shape %s, but the specified shape "
"invariant is %s." % (inp.values.name, inp.values.get_shape(),
shape))
var.values.set_shape(shape)
var.indices.set_shape(tensor_shape.TensorShape([shape[0]]))
if var.dense_shape is not None:
var.dense_shape.set_shape(tensor_shape.TensorShape([shape.ndims]))
else:
if not _ShapeLessThanOrEqual(inp.dense_shape.get_shape(), shape):
raise ValueError(
"The shape invariant specified for %s is not compatible with "
"the initial shape of the shape tensor of this SparseTensor. "
"It enters the loop with shape %s, but the specified shape "
"invariant is %s." % (inp.dense_shape.name,
inp.dense_shape.get_shape(), shape))
var.values.set_shape(tensor_shape.TensorShape([None]))
var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims]))
var.dense_shape.set_shape(shape)
def _EnforceShapeInvariant(merge_var, next_var):
  """Checks that a loop variable keeps a compatible shape across an iteration.

  Args:
    merge_var: The loop variable as seen at the loop's merge point (a Tensor,
      IndexedSlices or SparseTensor).
    next_var: The value of the same loop variable after one loop iteration.

  Raises:
    ValueError: If the shape of `next_var` is not compatible with the shape
      of `merge_var`.
    TypeError: If `merge_var` is not a Tensor, IndexedSlices or SparseTensor.
  """
  if isinstance(merge_var, ops.Tensor):
    merge_shape = merge_var.get_shape()
    next_shape = next_var.get_shape()
    if _ShapeLessThanOrEqual(next_shape, merge_shape):
      return
    # Walk back from the merge op to the loop's Enter op so the error message
    # can name the tensor that was originally fed into the loop.
    enter_op = merge_var.op.inputs[0].op
    assert util.IsLoopEnter(enter_op)
    loop_input = enter_op.inputs[0]
    raise ValueError(
        "Input tensor '%s' enters the loop with shape %s, but has shape %s "
        "after one iteration. To allow the shape to vary across iterations, "
        "use the `shape_invariants` argument of tf.while_loop to specify a "
        "less-specific shape." %
        (loop_input.name, loop_input.shape, next_shape))

  if not isinstance(merge_var,
                    (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    raise TypeError("Type %s not supported" % type(merge_var))

  m_values_shape = merge_var.values.get_shape()
  m_indices_shape = merge_var.indices.get_shape()
  n_values_shape = next_var.values.get_shape()
  n_indices_shape = next_var.indices.get_shape()

  if isinstance(merge_var, ops.IndexedSlices):
    # A missing dense_shape is reported as a shape of unknown rank.
    if merge_var.dense_shape is not None:
      m_shape_shape = merge_var.dense_shape.get_shape()
    else:
      m_shape_shape = tensor_shape.TensorShape(None)
    if next_var.dense_shape is not None:
      n_shape_shape = next_var.dense_shape.get_shape()
    else:
      n_shape_shape = tensor_shape.TensorShape(None)
    # Only a mismatch of the values shape is reported for IndexedSlices; a
    # mismatch limited to the indices shape does not raise.
    violated = not _ShapeLessThanOrEqual(n_values_shape, m_values_shape)
  else:
    m_shape_shape = merge_var.dense_shape.get_shape()
    n_shape_shape = next_var.dense_shape.get_shape()
    violated = not (
        _ShapeLessThanOrEqual(n_values_shape, m_values_shape) and
        _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) and
        _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape))

  if violated:
    raise ValueError(
        "The shape for %s is not an invariant for the loop. It enters "
        "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
        "after one iteration. Provide shape invariants using either the "
        "`shape_invariants` argument of tf.while_loop or set_shape() "
        "on the loop variables." %
        (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
         n_values_shape, n_indices_shape, n_shape_shape))
def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Adds NextIteration and the back edge from `v` to `m`.

  Args:
    m: The merge value of a loop variable (Tensor, IndexedSlices or
      SparseTensor); input 1 of each underlying op is rewired.
    v: The value of the loop variable produced by the loop body.
    enforce_shape_invariant: If true and `m` is a Tensor, check the shape
      invariant before rewiring.

  Returns:
    The NextIteration value created for `v`.

  Raises:
    ValueError: If `m` has a dense shape or is sparse and `v` does not match.
    TypeError: If `m` is not a Tensor, IndexedSlices or SparseTensor.
  """
  # pylint: disable=protected-access
  if isinstance(m, ops.Tensor):
    next_v = _NextIteration(ops.convert_to_tensor(v))
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, next_v)
    m.op._update_input(1, next_v)
    return next_v

  if isinstance(m, ops.IndexedSlices):
    next_v = _NextIteration(math_ops._as_indexed_slices(v, optimize=False))
    m.values.op._update_input(1, next_v.values)
    m.indices.op._update_input(1, next_v.indices)
    if m.dense_shape is not None:
      if next_v.dense_shape is None:
        raise ValueError("Must have dense shape: %s" % next_v.name)
      m.dense_shape.op._update_input(1, next_v.dense_shape)
    return next_v

  if isinstance(m, sparse_tensor.SparseTensor):
    if not isinstance(v, sparse_tensor.SparseTensor):
      raise ValueError("Must be a sparse tensor: %s" % v.name)
    next_v = _NextIteration(v)
    m.values.op._update_input(1, next_v.values)
    m.indices.op._update_input(1, next_v.indices)
    m.dense_shape.op._update_input(1, next_v.dense_shape)
    return next_v
  # pylint: enable=protected-access

  raise TypeError("Type %s not supported" % type(m))
def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
  """Calculate a max_size for use by stack ops inside an XLA while_loop.

  The result is the product of `maximum_iterations` over every while context
  between `value` and the current control flow context.

  Args:
    value: The value inside the while_loop forward context.  Used for printing
      error messages.
    while_ctxt: The forward context inside which value resides.  This does
      not always match the value's immediate context, as `value` may be
      inside e.g. a cond context inside the while_loop.

  Returns:
    A tensor containing the `max_size` to feed to a Stack initializer.

  Raises:
    ValueError: If `value` is nested inside a `while_loop` that either
      lacks a `maximum_iterations` parameter, or the `maximum_iterations`
      parameter:

        - is inside a `while_loop` that is a parent of the calling context, and
        - cannot be evaluated at graph build time to a constant.
  """
  value_name = value.name
  # curr_ctxt is the context that tf.gradients was called in.
  # pylint: disable=protected-access
  curr_ctxt = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  curr_ctxt_name = "" if curr_ctxt is None else curr_ctxt.name

  max_size = constant_op.constant(1)

  # Walk outward through the containing while contexts until the current
  # context (or the top level) is reached, multiplying in each context's
  # maximum_iterations to get the maximum stack size.
  while while_ctxt not in (None, curr_ctxt):
    max_iter = while_ctxt.maximum_iterations
    if max_iter is None:
      raise ValueError(
          "Cannot create a gradient accumulator for tensor '%s' inside "
          "XLA while_loop because maximum_iterations was not passed to "
          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))

    max_iter_ctxt = max_iter.op._get_control_flow_context()  # pylint: disable=protected-access

    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
      # max_iter_ctxt (non-strictly) contains curr_ctxt, so the tensor itself
      # can be used.
      factor = max_iter
    else:
      # We cannot use max_iter because it's defined in a nested while
      # or cond context, so will fail if we try to use it as input to
      # any ops in curr_ctxt (e.g. max_size or the final accumulator
      # stack). Attempt to get a constant value out to use instead.
      factor = tensor_util.constant_value(max_iter)
      if factor is None:
        raise ValueError(
            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
            "while_loop. maximum_iterations tensor '%s' for while_loop context "
            "'%s' must be statically known (e.g. a constant value or known "
            "shape dimension), or be defined at or outside the while loop "
            "context '%s' (currently defined in '%s')." %
            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
             max_iter_ctxt.name))
    max_size *= factor

    # Find the next outer WhileContext (or stop if we reach the
    # tf.gradient's context).
    while_ctxt = util.GetContainingWhileContext(
        while_ctxt.outer_context, stop_ctxt=curr_ctxt)

  return max_size
class GradLoopState(object):
"""The state used for constructing the gradient graph for a while loop.
We create a GradLoopState for each while loop in forward and its
corresponding while loop in backprop. This gives us access to both
the forward and the backprop WhileContexts.
During the construction of gradient graph, any time when we detect
a forward value that is needed for backprop, we create a history
accumulator and add it to `history_map`. Any time when we backprop
a loop switch op (in _SwitchGrad), we add the grad merge op in
`switch_map`.
"""
def __init__(self, forward_ctxt, outer_grad_state):
# The grad loop state for the outer while loop.
self._outer_grad_state = None
# The while loop context for forward.
self._forward_context = None
# The loop counter added by AddForwardLoopCounter. It is the value
# of the loop counter for the next iteration.
self._forward_index = None
# A sync op for forward.
self._forward_sync = None
# The while loop context for backprop.
self._grad_context = None
# The loop counter added by AddBackpropLoopCounter. It is the value
# of the loop counter for the current iteration.
self._grad_index = None
# A sync op for backprop.
self._grad_sync = None
# Information needed by backprop.
self._history_map = {}
self._switch_map = {}
self._unused_exits = []
self._deferred_exits = []
self._forward_loop_exits = list(forward_ctxt.loop_exits)
self._pending_exits_count = len(forward_ctxt.loop_exits)
self._outer_grad_state = outer_grad_state
if outer_grad_state:
outer_forward_ctxt = outer_grad_state.forward_context
else:
if not hasattr(forward_ctxt, "outer_context"):
raise ValueError("Failed to call gradients on a while loop without"
"properly serializing graph via MetaGraphDef")
outer_forward_ctxt = forward_ctxt.outer_context
# Add the forward loop counter.
with forward_ctxt._graph.as_default(): # pylint: disable=protected-access
if outer_forward_ctxt:
outer_forward_ctxt.Enter()
cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
if outer_forward_ctxt:
outer_forward_ctxt.Exit()
self._forward_context = forward_ctxt
self._forward_index = forward_index
# Add the backprop WhileContext, and the backprop loop counter.
if outer_grad_state:
# This is a nested loop. Remember the iteration counts for each
# execution of this inner loop.
outer_forward_ctxt.AddName(cnt.name)
history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
outer_grad_ctxt = outer_grad_state.grad_context
outer_grad_ctxt.Enter()
self._grad_context = WhileContext(
maximum_iterations=forward_ctxt.maximum_iterations,
parallel_iterations=forward_ctxt.parallel_iterations,
back_prop=forward_ctxt.back_prop,
swap_memory=forward_ctxt.swap_memory,
name=forward_ctxt.name,
grad_state=self)
real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
self._grad_index = self._grad_context.AddBackpropLoopCounter(
real_cnt, outer_grad_state)
outer_grad_ctxt.Exit()
else:
if outer_forward_ctxt:
outer_forward_ctxt.Enter()
self._grad_context = WhileContext(
maximum_iterations=forward_ctxt.maximum_iterations,
parallel_iterations=forward_ctxt.parallel_iterations,
back_prop=forward_ctxt.back_prop,
swap_memory=forward_ctxt.swap_memory,
name=forward_ctxt.name,
grad_state=self)
self._grad_index = self._grad_context.AddBackpropLoopCounter(
cnt, outer_grad_state)
if outer_forward_ctxt:
outer_forward_ctxt.Exit()
@property
def outer_grad_state(self):
"""The grad loop state for outer loop."""
return self._outer_grad_state
@property
def forward_context(self):
"""The while loop context for forward."""
return self._forward_context
@property
def forward_index(self):
"""The loop index of forward loop."""
return self._forward_index
@property
def forward_sync(self):
"""A control trigger node for synchronization in the forward loop.
One main use is to keep the push ops of a stack executed in the
iteration order.
"""
if self._forward_sync is None:
with ops.control_dependencies(None):
self._forward_sync = control_trigger(name="f_sync")
self._forward_sync._set_control_flow_context(self._forward_context)
self._forward_index.op._add_control_input(self._forward_sync)
return self._forward_sync
@property
def grad_context(self):
  """Return the WhileContext used for the gradient (backprop) loop."""
  return self._grad_context
@property
def grad_index(self):
  """Return the loop index of the backprop loop."""
  return self._grad_index
@property
def grad_sync(self):
  """Return the control trigger that synchronizes the grad loop.

  The trigger is built on first access and cached in `_grad_sync`; later
  accesses return the cached node. One main use is to keep the pop ops of a
  stack executed in the iteration order.
  """
  if self._grad_sync is not None:
    return self._grad_sync
  # Build the trigger with no inherited control dependencies.
  with ops.control_dependencies(None):
    trigger = control_trigger(name="b_sync")
  self._grad_sync = trigger
  # pylint: disable=protected-access
  trigger._set_control_flow_context(self._grad_context)
  # The backprop index op is made to wait on the trigger.
  self._grad_index.op._add_control_input(trigger)
  # pylint: enable=protected-access
  # Let the context enclosing the grad loop, if any, know about the new op.
  if self._grad_context.outer_context:
    self._grad_context.outer_context.AddInnerOp(trigger)
  return trigger
@property
def history_map(self):
  """Return the map recording all the tensors needed for backprop."""
  return self._history_map
@property
def switch_map(self):
  """Return the map recording all the Switch ops of the while loop."""
  return self._switch_map
@property
def unused_exits(self):
  """Return the list of "unused" exits collected for this loop."""
  return self._unused_exits
@property
def deferred_exits(self):
  """Return the list of "deferred" exits collected for this loop."""
  return self._deferred_exits
@property
def forward_loop_exits(self):
  """Return the list of exits of the forward loop."""
  return self._forward_loop_exits
@property
def pending_exits_count(self):
  """Return how many exits are still expected but have not been seen."""
  return self._pending_exits_count


@pending_exits_count.setter
def pending_exits_count(self, cnt):
  """Overwrite the number of pending exits with `cnt`."""
  self._pending_exits_count = cnt
def AddForwardAccumulator(self, value, dead_branch=False):
"""Add an accumulator for each forward tensor that is needed in backprop.
This is added to the forward loop at the first time when a tensor
in the forward loop is used by backprop gradient computation loop.
We create an accumulator that accumulates the value of tensor at each
iteration. Called in the control flow context where gradients() is called.
The pseudocode is:
```
acc = stack();
while (_pivot) {
acc = stack_push(acc, value);
}
```
We make sure that the stack push op in one iteration is executed before
next iteration. This is achieved by adding a control edge from
`forward_index.op.inputs[0].op` to the push op, and another control
edge from the push op to either `forward_index.op` or `forward_sync`.
Args:
value: The source tensor in forward that is to be accumulated.
dead_branch: True iff the tensor is on a dead branch of a cond.
Returns:
The stack that contains the accumulated history of the tensor.
Raises:
TypeError: For internal errors involving the value condition context.
ValueError: If `value` is inside a XLA scope and a valid max size
for the stack can't be found.
"""
# All ops below are built in the graph that owns the forward loop index.
# curr_ctxt is the context that tf.gradients was called in.
with self._forward_index.graph.as_default():
curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
# Build without any control dependencies inherited from the caller.
with ops.control_dependencies(None):
# The stack itself is created in curr_ctxt when there is one.
if curr_ctxt:
curr_ctxt.Enter()
with ops.colocate_with(value):
# We only need to pass maximum_iterations to the stack if
# we're inside an XLA context.
if not util.IsInXLAContext(value.op):
# NOTE(review): -1 presumably means "no size limit" for the stack op;
# confirm against the StackV2 op documentation.
max_size = constant_op.constant(-1, dtypes.int32)
else:
max_size = GetMaxSizeFromNestedMaximumIterations(
value, self.forward_context)
acc = gen_data_flow_ops.stack_v2(
max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
if curr_ctxt:
curr_ctxt.Exit()
# Make acc available in the forward context.
enter_acc = self.forward_context.AddValue(acc)
# Add the stack_push op in the context of value.op.
swap_enabled = self.forward_context.swap_memory
value_ctxt = util.GetOutputContext(value.op)
if value_ctxt == self.forward_context:
# value is not nested in the forward context.
self.forward_context.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
self.forward_context.Exit()
# Protect stack push and order it before forward_index.
self.forward_index.op._add_control_input(push.op)
else:
# value is in a cond context within the forward context.
if not isinstance(value_ctxt, CondContext):
raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
if dead_branch:
# The special case for creating a zero tensor for a dead
# branch of a switch. See ControlFlowState.ZerosLike().
# The push is built in the cond's outer context and then reassigned to
# the cond context itself.
value_ctxt.outer_context.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
value_ctxt.outer_context.Exit()
push.op._set_control_flow_context(value_ctxt)
else:
value_ctxt.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
value_ctxt.Exit()
# Protect stack push and order it before forward_sync.
self.forward_sync._add_control_input(push.op)
# Order stack push after the successor of forward_index
add_op = self.forward_index.op.inputs[0].op
push.op._add_control_input(add_op)
# The stack handle (not the pushed value) is what callers receive.
return acc
def AddBackpropAccumulatedValue(self, history_value, value,
dead_branch=False):
"""Add the getter for an accumulated value in the grad context.
This is added to the backprop loop. Called in the grad context to
get the value of an accumulated value. The stack pop op must be guarded
by the pred of the controlling cond.
Args:
history_value: The history (a stack) of a value.
value: The value that is pushed onto the stack.
dead_branch: True iff the tensor is on a dead branch of a cond.
Returns:
The current value (the top of the stack).
"""
history_ctxt = history_value.op._get_control_flow_context()
# Find the cond context that controls history_value if any.
cond_ctxt = None
value_ctxt = value.op._get_control_flow_context()
# Walk outward from value's context; stop at the first CondContext or once
# the context of the history stack is reached.
while value_ctxt and value_ctxt != history_ctxt:
if isinstance(value_ctxt, CondContext):
cond_ctxt = value_ctxt
break
value_ctxt = value_ctxt.outer_context
# The pop is built inside the grad context with no inherited control deps.
with ops.control_dependencies(None):
self.grad_context.Enter()
if cond_ctxt:
# Guard stack pop with a switch if it is controlled by a cond.
grad_state = self
pred = None
# Prefer a predicate recorded in the history map of this grad state or of
# an outer one; otherwise fall back to the forward predicate.
while pred is None and grad_state:
pred = grad_state.history_map.get(cond_ctxt.pred.name)
grad_state = grad_state.outer_grad_state
if pred is None:
pred = cond_ctxt.pred
# A dead-branch value is read through the opposite switch output.
branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
history_value = _SwitchRefOrTensor(history_value, pred)[branch]
pop = gen_data_flow_ops.stack_pop_v2(history_value,
value.dtype.base_dtype)
# Carry value's static shape over to the popped tensor.
pop.set_shape(value.get_shape())
self.grad_context.Exit()
# The pop is tied to grad_sync only when parallel_iterations exceeds 1.
parallel_iterations = self.grad_context.parallel_iterations
if parallel_iterations > 1:
# All pops are ordered after pivot_for_body and before grad_sync.
self.grad_sync._add_control_input(pop.op)
return pop
def GetRealValue(self, value):
"""Get the real value of `value`.
If backprop "uses" a value produced by forward inference, an accumulator
is added in the forward loop to accumulate its values. We use the
accumulated value. This method must be called in the grad loop context.
`value` must be in forward and needed for backprop.
Args:
value: A tensor to be captured.
Returns:
The same tensor obtained from the saved history.
"""
# Variable ops are rejected up front.
assert value.op.type not in ["Variable", "VariableV2"]
# Reuse the result cached under this tensor's name by an earlier call.
real_value = self._history_map.get(value.name)
if real_value is None:
cur_value = value
cur_grad_state = self
# Each constant Enter that is peeled off moves one grad state outward; any
# other kind of value ends the loop on its first visit.
while True:
enter_op = util.GetLoopConstantEnter(cur_value)
if enter_op:
# Special case: cur_value comes from a constant Enter node.
cur_value = enter_op.inputs[0]
cur_grad_state = cur_grad_state.outer_grad_state
if cur_grad_state is None:
# We are now outside all nested loops for this gradient(),
# so `value` is a loop invariant and there is no need to
# save the history of value. Just make cur_value to enter
# the right control flow context.
real_value = self._grad_context.AddValue(cur_value)
break
elif constant_op.is_constant(cur_value):
# If the value to be forwarded is a constant, clone the constant in
# the gradient loop rather than using a stack.
# TODO(phawkins): consider hoisting the constant out of the loop
# instead.
real_value = constant_op.constant(
tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
break
else:
# Record the history of this value in forward_ctxt.
# The grad context is left while the forward accumulator is added and
# re-entered afterwards.
self._grad_context.Exit()
history_value = cur_grad_state.AddForwardAccumulator(cur_value)
self._grad_context.Enter()
break
if real_value is None:
# Add the stack pop op in the grad context.
real_value = cur_grad_state.AddBackpropAccumulatedValue(
history_value, cur_value)
# When the pop was added through a different grad state, route it into
# this state's grad context.
if cur_grad_state != self:
real_value = self._grad_context.AddValue(real_value)
# Cache the result under the name of the original tensor.
self._history_map[value.name] = real_value
return real_value
def _GetWhileContext(op):
  """Return the while context that `op` belongs to.

  Args:
    op: An operation exposing `_get_control_flow_context()`.

  Returns:
    The result of `GetWhileContext()` on the op's control flow context, or
    the (falsy) context itself when the op has none.
  """
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  # pylint: enable=protected-access
  return op_ctxt.GetWhileContext() if op_ctxt else op_ctxt
class ControlFlowState(object):
  """Maintain the mapping from the loops to their grad states.

  `_map` is keyed by the forward while-loop context and holds the
  `GradLoopState` that `AddWhileContext` created for that loop.
  """

  def __init__(self):
    self._map = {}  # maps forward loop context to GradLoopState

  def GetGradState(self, op, before):
    """Return the grad state for this op if it's in a forward loop context.

    Args:
      op: The operation whose loop is looked up.
      before: If true and `op` is a loop exit, look up the while context that
        encloses the exited loop instead of the loop `op` itself belongs to.

    Returns:
      The `GradLoopState` registered for the selected loop context, or None
      when there is no such context or no state was registered for it.
    """
    if before and util.IsLoopExit(op):
      forward_ctxt = op._get_control_flow_context()
      forward_ctxt = forward_ctxt.outer_context
      if forward_ctxt:
        forward_ctxt = forward_ctxt.GetWhileContext()
    else:
      forward_ctxt = _GetWhileContext(op)
    if forward_ctxt:
      return self._map.get(forward_ctxt)
    return None

  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
    """Process all the "unused" loop exits.

    The "unused" exits of the loops are added to `unused_exits`. An exit is
    unused if its pending_count is 0. If there is an exit with real gradient,
    all these deferred exits will enter the backprop loop with zero gradient.
    Otherwise, they will enter the backprop loop with None. As an example,
    people often write:

    ```python
    v1, _ = tf.while_loop(p, b, [x1, x2])
    result = gradients(v1, x1)
    ```

    The exit node for x2 is not included by the betweenness analysis. But we
    need to backprop x2 if x2 is involved in computing v1.

    Args:
      pending_count: The number of backprop inputs for every op. Entries for
        loop Enter ops that are 0 are set to 1 by this method.
      to_ops_set: The set of ops for ys in gradients(ys, xs)

    Returns:
      The list of unused loop exits that we know at this point we need
      to backprop.
    """
    loop_exits = []
    for grad_state in self._map.values():
      for y in grad_state.forward_loop_exits:
        if pending_count[y.op] == 0:
          grad_state.pending_exits_count -= 1
          if y.op not in to_ops_set:
            grad_state.unused_exits.append(y)
          if grad_state.pending_exits_count == 0:
            loop_exits.extend(grad_state.unused_exits)
      # Need to include Enters in backprop for higher-order gradients.
      for y in grad_state.forward_context.loop_enters:
        if pending_count[y.op] == 0:
          pending_count[y.op] = 1
    return loop_exits

  def EnterGradWhileContext(self, op, before):
    """Enter the WhileContext for gradient computation.

    Does nothing when `GetGradState(op, before)` finds no grad state.
    """
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Enter()

  def ExitGradWhileContext(self, op, before):
    """Exit the WhileContext for gradient computation.

    Does nothing when `GetGradState(op, before)` finds no grad state.
    """
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Exit()

  def AddWhileContext(self, op, between_op_list, between_ops):
    """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in
    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.

    Args:
      op: An Exit op of the while loop to register.
      between_op_list: List of ops; exits of a newly registered loop that are
        not yet in `between_ops` are appended to it.
      between_ops: Set of ops; kept in sync with `between_op_list`.
    """
    forward_ctxt = _GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # This is a new while loop so create a grad state for it.
      outer_forward_ctxt = forward_ctxt.outer_context
      if outer_forward_ctxt:
        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
      outer_grad_state = None
      if outer_forward_ctxt:
        outer_grad_state = self._map.get(outer_forward_ctxt)
      grad_state = GradLoopState(forward_ctxt, outer_grad_state)
      self._map[forward_ctxt] = grad_state

      # We need to include all exits of a loop for backprop.
      for loop_exit in grad_state.forward_loop_exits:
        if loop_exit.op not in between_ops:
          between_ops.add(loop_exit.op)
          between_op_list.append(loop_exit.op)

  def ZerosLikeForExit(self, val):
    """Create zeros_like gradient for a loop exit.

    If the result of a loop variable is not used but is involved in
    computing the result of some needed loop variable, we create a
    zero-valued tensor that is fed as gradient for the Exit node of that
    loop variable. Note that val.op is an Exit, and this method must be
    called in the control flow context where gradients() is called.

    Args:
      val: The output tensor of an Exit op.

    Returns:
      A zero tensor of the same shape of val.
    """
    val_shape = val.get_shape()
    forward_ctxt = val.op._get_control_flow_context()
    outer_forward_ctxt = forward_ctxt.outer_context
    if outer_forward_ctxt:
      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
    outer_grad_state = None
    if outer_forward_ctxt:
      outer_grad_state = self._map.get(outer_forward_ctxt)
    if outer_grad_state:
      # This is a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape in the right context.
        outer_grad_state.grad_context.Enter()
        result = array_ops.zeros(val_shape.dims, val.dtype)
        outer_grad_state.grad_context.Exit()
      else:
        # Only the shape of value is needed for backprop.
        forward_ctxt.outer_context.Enter()
        shape = array_ops.shape_internal(val, optimize=False)
        forward_ctxt.outer_context.Exit()
        # Save the shape to a stack.
        history_shape = outer_grad_state.AddForwardAccumulator(shape)
        # Get the shape back from the stack.
        outer_grad_ctxt = outer_grad_state.grad_context
        outer_grad_ctxt.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_shape, shape)
        result = array_ops.zeros(real_shape, val.dtype)
        outer_grad_ctxt.Exit()
    else:
      # This is not a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape.
        result = array_ops.zeros(val_shape.dims, val.dtype)
      else:
        result = array_ops.zeros_like(val, optimize=False)
    return result

  def ZerosLike(self, op, index):
    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape of op.outputs[index], or None when `op`
      is a loop switch.
    """
    if util.IsLoopSwitch(op):
      return None
    dead_branch = util.IsSwitch(op)
    forward_ctxt = _GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # op is not in a while loop that is part of gradients().
      return ZerosLikeOutsideLoop(op, index)
    op_ctxt = op._get_control_flow_context()
    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
    shape = val.get_shape()
    if shape.is_fully_defined():
      # If the shape is known statically, just create a zero tensor with
      # the right shape in the grad loop context.
      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
      if dead_branch:
        # op is a cond switch. Guard the zero tensor with a switch.
        pred = grad_state.history_map.get(op_ctxt.pred.name)
        branch = op_ctxt.branch
        result = _SwitchRefOrTensor(result, pred)[1 - branch]
    else:
      # Unknown shape so keep a history of the shape at runtime.
      if dead_branch:
        # Need to add a special switch to guard the value.
        pred = op_ctxt.pred
        branch = op_ctxt.branch
        op_ctxt.outer_context.Enter()
        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.outer_context.Exit()
        val.op._set_control_flow_context(op_ctxt)
        zeros_shape.op._set_control_flow_context(op_ctxt)
      else:
        op_ctxt.Enter()
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.Exit()

      # Add forward accumulator for shape.
      grad_state.grad_context.Exit()
      history_zeros_shape = grad_state.AddForwardAccumulator(
          zeros_shape, dead_branch=dead_branch)
      grad_state.grad_context.Enter()

      # Create a zero tensor with the right shape.
      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
                                                     zeros_shape, dead_branch)
      result = array_ops.zeros(shape, val.dtype)
    return result

  def PostProcessing(self):
    """Perform postprocessing at the end of gradients().

    We have created the gradient graph at this point. So this function
    can be used to perform any postprocessing on the gradient graph.
    We currently perform the following postprocessing:
      1. Patch the gradient graph if the output of a loop variable
         doesn't depend on its input.
    """
    for grad_state in self._map.values():
      for b_merge in grad_state.switch_map.values():
        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
          # The value of this loop variable at iteration i+1 doesn't
          # depend on its value at iteration i. So use zeros as the
          # gradients for all iterations > 0.
          dtype = b_merge.op.inputs[0].dtype
          shape = b_merge.op.inputs[0].get_shape()
          # pylint: disable=protected-access
          if shape.is_fully_defined():
            grad_state.grad_context.Enter()
            # Create a zeros and use it for iterations > 0.
            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          else:
            # Create a zeros in the outer grad context.
            outer_grad_ctxt = grad_state.grad_context.outer_context
            if outer_grad_ctxt:
              outer_grad_ctxt.Enter()
            enter_grad_op = b_merge.op.inputs[0].op
            enter_grad = enter_grad_op.inputs[0]
            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
            # The zeros must have the dtype of the merged gradient; without
            # an explicit dtype array_ops.zeros would default to float32 and
            # mismatch the Merge input for any other gradient type.
            grad_val = array_ops.zeros(grad_shape, dtype=dtype)
            if outer_grad_ctxt:
              outer_grad_ctxt.Exit()
            # Use the zeros for iterations > 0.
            grad_state.grad_context.Enter()
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          b_merge.op._update_input(1, next_grad_val)
          # pylint: enable=protected-access
def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Build the state for every while loop touched by one gradients() call.

  A ControlFlowState is only created when at least one op in
  `between_op_list` is a loop exit. In gradients(), control flow logic is
  only invoked when the ControlFlowState is not None.

  Note that this method modifies `between_op_list` and `between_ops`.

  Args:
    between_op_list: List of ops to scan for loop exits.
    between_ops: Set of ops passed along to `AddWhileContext`.
    colocate_gradients_with_ops: If true, each loop is registered inside
      `ops.colocate_with(op)` for the exit op that triggered it.

  Returns:
    The ControlFlowState, or None if no loop exit was found.
  """
  state = None
  # The list is iterated live on purpose: AddWhileContext may append to it.
  for candidate in between_op_list:
    if not util.IsLoopExit(candidate):
      continue
    if state is None:
      state = ControlFlowState()
    if not colocate_gradients_with_ops:
      state.AddWhileContext(candidate, between_op_list, between_ops)
      continue
    with ops.colocate_with(candidate):
      state.AddWhileContext(candidate, between_op_list, between_ops)
  return state
def ZerosLikeOutsideLoop(op, index):
  """Create zeros_like for the specified output of an op.

  Args:
    op: The operation whose output is mirrored.
    index: Index into `op.outputs` selecting the tensor.

  Returns:
    A zero tensor. For a Switch op that has a control flow context the zeros
    are built from the other switch output so that they are only created on
    that branch.
  """
  val = op.outputs[index]
  if not util.IsSwitch(op):
    if val.dtype == dtypes.resource:
      return array_ops.zeros(gen_resource_variable_ops.variable_shape(val))
    return array_ops.zeros_like(val, optimize=False)

  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if not op_ctxt:
    return array_ops.zeros_like(val, optimize=False)

  # We are in a cond context. Use a switch to create zeros only when needed.
  switch_val = switch(op.inputs[0], op_ctxt.pred)[1 - op_ctxt.branch]
  # A op is created along the branch taken as control dependencies are on
  # the whole op and not on the tensor output.
  pivot = array_ops.identity(switch_val)
  if val.dtype == dtypes.resource:
    with ops.control_dependencies([pivot]):
      return array_ops.zeros(
          gen_resource_variable_ops.variable_shape(switch_val))
  zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
  # Ensure ops created within array_ops.zeros are dominated by switch in
  # cond context.
  with ops.control_dependencies([pivot]):
    return array_ops.zeros(zeros_shape, dtype=val.dtype)
class ControlFlowContext(object):
  """Common base for the control flow contexts used while building a graph.

  Typical use is a series of (Enter, Exit) pairs followed by one final
  ExitResult.

  State kept for control flow contexts during graph construction:
    1. The graph's _control_flow_context is the context new nodes are built
       in; ctxt.Enter() and ctxt.Exit() change it.
    2. Each op's _control_flow_context is the context the op belongs to,
       normally set when the op is created.
    3. A context's _outer_context is the context it was created in, set once
       at construction.
    4. A context's _context_stack is pushed by ctxt.Enter() and popped by
       ctxt.Exit().
  """

  def __init__(self, values_def=None, import_scope=None):
    """Create a context nested in the graph's current context.

    Args:
      values_def: Optional `ValuesDef` protocol buffer to restore the seen
        and external values from.
      import_scope: Optional `string`. Name scope to add when restoring.
    """
    self._nested_contexts = []
    self._outer_context = ops.get_default_graph()._get_control_flow_context()
    if self._outer_context:
      # Register with the parent so it can enumerate its nested contexts.
      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
    self._context_stack = []
    if values_def:
      self._init_values_from_proto(values_def, import_scope=import_scope)
      return
    # The names of tensors that have been already seen in this context.
    self._values = set()
    # The keys are the names of tensors referenced by but external to this
    # context. Each value is the Tensor that should be used by this context to
    # access the key value (e.g. a switch output guarding a cond input value).
    self._external_values = {}

  def _init_values_from_proto(self, values_def, import_scope=None):
    """Restore values and external_values from a `ValuesDef` protocol buffer.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = {
        ops.prepend_name_scope(name, import_scope)
        for name in values_def.values
    }
    graph = ops.get_default_graph()
    self._external_values = {}
    for key, target in values_def.external_values.items():
      scoped_key = ops.prepend_name_scope(key, import_scope)
      self._external_values[scoped_key] = graph.as_graph_element(
          ops.prepend_name_scope(target, import_scope))
    # Ops producing the non-external values are claimed by this context.
    internal_names = self._values - set(self._external_values.keys())
    op_names = {tensor_name.split(":")[0] for tensor_name in internal_names}
    for op_name in op_names:
      # pylint: disable=protected-access
      graph.as_graph_element(op_name)._set_control_flow_context(self)
      # pylint: enable=protected-access

  @property
  def name(self):
    """Name of this context (`_name` is not assigned by this base class)."""
    return self._name

  @property
  def outer_context(self):
    """The context that contains this context."""
    return self._outer_context

  @property
  def grad_state(self):
    """Abstract; always raises NotImplementedError in the base class."""
    raise NotImplementedError("Abstract method")

  @property
  def back_prop(self):
    """Abstract; always raises NotImplementedError in the base class."""
    raise NotImplementedError("Abstract method")

  @abc.abstractmethod
  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Serialize this context into `context_def`.

    Args:
      context_def: a `ControlFlowContextDef` protocol buffer.
      export_scope: Optional `string`. Name scope to remove.
    """
    raise NotImplementedError("Abstract method")

  def _to_values_def(self, export_scope=None):
    """Convert the recorded values to a `ValuesDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `ValuesDef` protocol buffer.
    """
    proto = control_flow_pb2.ValuesDef()
    # Seen values are emitted in sorted order.
    stripped_values = [
        ops.strip_name_scope(value_name, export_scope)
        for value_name in sorted(self._values)
    ]
    proto.values.extend(stripped_values)
    for key, tensor in self._external_values.items():
      stripped_key = ops.strip_name_scope(key, export_scope)
      proto.external_values[stripped_key] = ops.strip_name_scope(
          tensor.name, export_scope)
    return proto

  def AddName(self, name):
    """Record `name` as seen by this context."""
    self._values.add(name)

  # pylint: disable=protected-access
  def Enter(self):
    """Make this the graph's current context, remembering the previous one."""
    graph = ops.get_default_graph()
    previous = graph._get_control_flow_context()
    self._context_stack.append(previous)
    graph._set_control_flow_context(self)

  def Exit(self):
    """Restore the context that was current before the matching Enter()."""
    graph = ops.get_default_graph()
    previous = self._context_stack.pop()
    graph._set_control_flow_context(previous)

  def EnterGradientColocation(self, op, gradient_uid):
    """Start building a gradient colocated with an op.

    The base implementation only forwards the call to the outer context.
    """
    outer = self._outer_context
    if outer:
      outer.EnterGradientColocation(op, gradient_uid)

  def ExitGradientColocation(self, op, gradient_uid):
    """Stop building a gradient colocated with an op.

    The base implementation only forwards the call to the outer context.
    """
    outer = self._outer_context
    if outer:
      outer.ExitGradientColocation(op, gradient_uid)

  def ExitResult(self, result):
    """Make a structure of tensors available in the outer context."""
    outer = self._outer_context
    if not outer:
      return
    nest.map_structure(lambda tensor: outer.AddName(tensor.name), result)

  def GetWhileContext(self):
    """Return the while context containing this context, or None."""
    outer = self._outer_context
    return outer.GetWhileContext() if outer else None

  def _IsInOuterContext(self, op):
    """Return True if `op`'s output context is an ancestor of this context."""
    target_ctxt = util.GetOutputContext(op)
    candidate = self.outer_context
    while candidate != target_ctxt:
      if candidate is None:
        return False
      candidate = candidate.outer_context
    return True

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op.

    Returns:
      A pair `(internal_control_inputs, external_control_inputs)`.
    """
    enclosing_while = self.GetWhileContext()
    # A control input of `op` is internal if it is in the same while
    # loop context as the enclosing while loop context of self.
    if enclosing_while is None:
      internal_control_inputs = op.control_inputs
    else:
      internal_control_inputs = []
      for control_input in op.control_inputs:
        input_ctxt = util.GetOutputContext(control_input)
        if input_ctxt is None:
          continue
        if input_ctxt.GetWhileContext() == enclosing_while:
          internal_control_inputs.append(control_input)
    external_control_inputs = []
    if len(internal_control_inputs) != len(op.control_inputs):
      # Something was filtered out: rewire the op with internal inputs only.
      external_control_inputs = list(
          set(op.control_inputs) - set(internal_control_inputs))
      op._remove_all_control_inputs()
      op._add_control_inputs(internal_control_inputs)
    return internal_control_inputs, external_control_inputs

  # pylint: enable=protected-access

  def AddInnerOp(self, op):
    """Notify this scope about an operator added to an inner scope."""
    outer = self._outer_context
    if outer:
      outer.AddInnerOp(op)

  def GetControlPivot(self):
    """Return the pivot node for this context, or None."""
    return None

  def IsWhileContext(self):
    """Return False; the base context is not a while context."""
    return False

  def IsCondContext(self):
    """Return False; the base context is not a cond context."""
    return False

  def IsXLAContext(self):
    """Return False; the base context is not an XLA context."""
    return False

  def __str__(self):
    return self.name
class CondContext(ControlFlowContext):
"""The context for the conditional construct."""
def __init__(self,
             pred=None,
             pivot=None,
             branch=None,
             name="cond_text",
             context_def=None,
             import_scope=None):
  """Build a `CondContext`, either from arguments or from a proto.

  Args:
    pred: The `boolean` tensor for the conditional predicate.
    pivot: The predicate tensor in this branch.
    branch: 0 or 1 representing this branch.
    name: Name of the `CondContext` python object.
    context_def: Optional `ContextDef` protocol buffer to initialize the
      `CondContext` object from. When given, `pred`, `pivot` and `branch`
      are not used.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.
  """
  self._name = ops.get_default_graph().unique_name(name)
  if context_def:
    self._init_from_proto(context_def, import_scope=import_scope)
    return

  # Initializes the default fields.
  ControlFlowContext.__init__(self)
  self._pred = pred  # The boolean tensor for the cond predicate
  self._pivot = pivot  # The predicate tensor in this branch
  self._branch = branch  # 0 or 1 representing this branch

  # Values considered to have been already seen in this context. pred is not
  # included in this context, so it is also recorded as an external value.
  self._values.add(pred.name)
  self._external_values[pred.name] = pred
  self._values.add(pivot.name)
  # The op producing the pivot is assigned to this context.
  pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access
def _init_from_proto(self, context_def, import_scope=None):
  """Populate this `CondContext` from a protocol buffer.

  Args:
    context_def: `CondContextDef` protocol buffer.
    import_scope: Optional `string`. Name scope to add.
  """
  assert isinstance(context_def, control_flow_pb2.CondContextDef)
  graph = ops.get_default_graph()

  def _scoped(stored_name):
    # Names in the proto are relative to import_scope.
    return ops.prepend_name_scope(stored_name, import_scope)

  self._name = _scoped(context_def.context_name)
  self._pred = graph.as_graph_element(_scoped(context_def.pred_name))
  self._pivot = graph.as_graph_element(_scoped(context_def.pivot_name))
  self._branch = context_def.branch
  # The base class restores the seen / external values from the proto.
  super(CondContext, self).__init__(
      values_def=context_def.values_def, import_scope=import_scope)
@property
def pred(self):
  """The boolean tensor for the cond predicate."""
  return self._pred
@property
def pivot(self):
  """The predicate tensor in this branch."""
  return self._pivot
@property
def branch(self):
  """0 or 1, identifying which branch of the cond this context is."""
  return self._branch
@property
def grad_state(self):
  """Grad state of the enclosing while context, or None if there is none."""
  if not self.GetWhileContext():
    return None
  return self.GetWhileContext().grad_state
@property
def back_prop(self):
  """Whether backprop is enabled for the while loop enclosing this cond.

  Returns:
    The `back_prop` value of the enclosing while context, or False when this
    cond is not nested in a while context.
  """
  while_ctxt = self.GetWhileContext()
  if while_ctxt:
    # The original evaluated this attribute without returning it, so the
    # property always fell through to False.
    return while_ctxt.back_prop
  return False
def GetControlPivot(self):
  """Return the pivot of this branch."""
  return self._pivot
def to_proto(self, export_scope=None):
  """Serialize this `CondContext` into a `CondContextDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    A `CondContextDef` protocol buffer, or None when `export_scope` is given
    and this context's name does not start with it.
  """
  if export_scope is not None and not self.name.startswith(export_scope):
    return None

  proto = control_flow_pb2.CondContextDef()
  proto.context_name = ops.strip_name_scope(self.name, export_scope)
  proto.pred_name = ops.strip_name_scope(self._pred.name, export_scope)
  proto.pivot_name = ops.strip_name_scope(self._pivot.name, export_scope)
  proto.branch = self._branch
  values_def = super(CondContext, self)._to_values_def(export_scope)
  proto.values_def.MergeFrom(values_def)
  # Each nested context serializes itself into a freshly added entry.
  for child in self._nested_contexts:
    child_def = proto.nested_contexts.add()
    child.to_control_flow_context_def(child_def)
  return proto
@staticmethod
def from_proto(context_def, import_scope=None):
  """Return a `CondContext` rebuilt from `context_def`.

  Nested context definitions are rebuilt while the new context is entered.

  Args:
    context_def: `CondContextDef` protocol buffer.
    import_scope: Optional `string`. Name scope to add.
  """
  restored = CondContext(context_def=context_def, import_scope=import_scope)
  restored.Enter()
  for child_def in context_def.nested_contexts:
    from_control_flow_context_def(child_def, import_scope=import_scope)
  restored.Exit()
  return restored
def to_control_flow_context_def(self, context_def, export_scope=None):
  """Copy this context's proto into the `cond_ctxt` field of `context_def`.

  Args:
    context_def: Protocol buffer exposing a `cond_ctxt` field.
    export_scope: Optional `string`. Name scope to remove.
  """
  cond_proto = self.to_proto(export_scope=export_scope)
  context_def.cond_ctxt.CopyFrom(cond_proto)
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
# A name already in _values has been seen by this context before.
if val.name in self._values:
# Use the real value if it comes from outer context. This is needed in
# particular for nested conds.
result = self._external_values.get(val.name)
result = val if result is None else result
else:
# First sighting: record the name, pull the value through the outer context
# if there is one, then guard it with a switch on this context's predicate.
result = val
self._values.add(val.name)
if self._outer_context:
result = self._outer_context.AddValue(val)
self._values.add(result.name)
self._external_values[result.name] = result
# The switch is built without inherited control dependencies.
with ops.control_dependencies(None):
result = _SwitchRefOrTensor(result, self._pred)[self._branch]
if self._outer_context:
self._outer_context.AddInnerOp(result.op)
# The op producing the switch output is marked unfetchable and assigned to
# this context.
result.op.graph.prevent_fetching(result.op)
# pylint: disable=protected-access
result.op._set_control_flow_context(self)
# pylint: enable=protected-access
# Remember the mapping so later lookups of val.name reuse this output.
self._values.add(result.name)
self._external_values[val.name] = result
return result
def AddOp(self, op):
  """Add `op` to this context by delegating to `_AddOpInternal`."""
  self._AddOpInternal(op)
def _AddOpInternal(self, op):
"""Add `op` to the current context."""
# Ops without data inputs are handled through control edges only.
if not op.inputs:
# If we're in a while loop, remove any control inputs from outside the
# loop.
self._RemoveExternalControlEdges(op)
# Unless a remaining control input already lives in this context, add a
# control edge from the pivot op.
if not any(util.OpInContext(input_op, self)
for input_op in op.control_inputs):
# pylint: disable=protected-access
op._add_control_input(self._pivot.op)
# pylint: enable=protected-access
else:
# Make each input to 'op' available in this CondContext. If an input is
# already part of this context there's nothing to do, but if it's
# external, AddValue() will handle adding the appropriate Switch node and
# other bookkeeping.
for index in range(len(op.inputs)):
x = op.inputs[index]
if op.type == "Merge" and x.op.type == "NextIteration":
# Edge case: if we're importing a while loop inside this CondContext,
# AddValue() will not correctly handle the NextIteration inputs to
# Merge node. The problem is that the NextIteration should also be
# part of this context, but if we're importing it won't have been
# processed and added to the context yet, so AddValue() will try to
# add a Switch which results in an invalid graph. Instead, we use the
# NextIteration input as-is here, and it will eventually be added to
# the context via AddOp().
real_x = x
else:
real_x = self.AddValue(x)
# Rewire the input only when AddValue substituted a different tensor.
if real_x != x:
# pylint: disable=protected-access
op._update_input(index, real_x)
# pylint: enable=protected-access
# Remove any external control dependency on this op.
self._RemoveExternalControlEdges(op)
# pylint: disable=protected-access
# Function calls and SymbolicGradient ops always get a control edge from the
# pivot op.
if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
op._add_control_input(self._pivot.op)
# pylint: enable=protected-access
# Mark op's outputs as seen by this context and any outer contexts.
output_names = [x.name for x in op.outputs]
ctxt = self
while ctxt is not None:
# pylint: disable=protected-access
ctxt._values.update(output_names)
ctxt = ctxt._outer_context
# pylint: enable=protected-access
# Only a loop Exit op in a context without an outer context stays fetchable.
if self._outer_context or not util.IsLoopExit(op):
op.graph.prevent_fetching(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
def _ProcessOutputTensor(self, val):
"""Process an output tensor of a conditional branch."""
real_val = val
if val.name not in self._values:
# Handle the special case of lambda: x
self._values.add(val.name)
if self._outer_context:
real_val = self._outer_context.AddValue(val)
self._values.add(real_val.name)
self._external_values[real_val.name] = real_val
real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
self._external_values[val.name] = real_val
else:
external_val = self._external_values.get(val.name)
if external_val is not None:
real_val = external_val
return real_val
def _BuildCondTensor(self, v):
  """Converts one value returned by a branch function into a branch output.

  Args:
    v: An `Operation`, an `IndexedSlices`, a `SparseTensor`, or anything
      `ops.convert_to_tensor` accepts after TensorArrays have been replaced
      by `_convert_tensorarray_to_flow`.

  Returns:
    The value with each component passed through `_ProcessOutputTensor`, in
    the same composite type as the input. An `Operation` is replaced by the
    pivot with a control dependency on that operation.
  """
  if isinstance(v, ops.Operation):
    # Use pivot as the proxy for this op.
    return with_dependencies([v], self._pivot)

  process = self._ProcessOutputTensor
  if isinstance(v, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    # Components are processed in a fixed order: values, indices, then shape.
    values = process(v.values)
    indices = process(v.indices)
    if isinstance(v, ops.IndexedSlices):
      # `dense_shape` is optional for IndexedSlices.
      dense_shape = v.dense_shape
      if dense_shape is not None:
        dense_shape = process(dense_shape)
      return ops.IndexedSlices(values, indices, dense_shape)
    dense_shape = process(v.dense_shape)
    return sparse_tensor.SparseTensor(indices, values, dense_shape)

  v = nest.map_structure(_convert_tensorarray_to_flow, v)
  return process(ops.convert_to_tensor(v))
def BuildCondBranch(self, fn):
  """Add the subgraph defined by fn() to the graph.

  Summaries that `fn` appends to the summary collection are removed from the
  collection again and attached to the branch result as control
  dependencies instead.

  Args:
    fn: The zero-argument branch callable.

  Returns:
    A pair `(original_result, result)`. `original_result` is what `fn`
    returned (wrapped in identity ops when new summaries were created);
    `result` is the same structure mapped through `_BuildCondTensor` and
    wrapped in a list if it is not already a list or tuple. When `fn`
    returns None the pair is `(None, None)`, or `(no_op(), None)` if new
    summaries were created.
  """
  summaries_key = ops.GraphKeys._SUMMARY_COLLECTION  # pylint: disable=protected-access
  summaries_before = ops.get_collection(summaries_key)
  original_result = fn()
  summaries_after = ops.get_collection(summaries_key)

  num_before = len(summaries_before)
  if len(summaries_after) > num_before:
    new_summaries = summaries_after[num_before:]
    # Restore the collection to its contents from before fn() ran.
    collection = ops.get_collection_ref(summaries_key)
    collection[:] = summaries_before
    with ops.control_dependencies(new_summaries):
      if original_result is None:
        return no_op(), None
      original_result = nest.map_structure(array_ops.identity,
                                           original_result)

  if original_result is None:
    return None, None

  result = nest.map_structure(self._BuildCondTensor, original_result)
  if not isinstance(result, (list, _basetuple)):
    result = [result]
  return original_result, result
def IsCondContext(self):
  """Returns True: this context is a conditional context."""
  return True
def _UnpackIfSingleton(res):
  """Returns `res[0]` if `res` is a one-element list or tuple, else `res`."""
  is_singleton = isinstance(res, (list, _basetuple)) and len(res) == 1
  return res[0] if is_singleton else res
# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export("cond")
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.
    fn1: Deprecated alias for `true_fn`; may not be combined with it.
    fn2: Deprecated alias for `false_fn`; may not be combined with it.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is missing or not callable, if a
      callable is given both by its current and its deprecated name, or if
      `pred` is a Python bool when building a graph.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  #
  # The deprecated aliases are resolved, and the callables validated, before
  # any dispatch so that the cond_v2 path also honors `fn1`/`fn2` and never
  # receives a missing or non-callable branch.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): true_fn argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): false_fn argument required")

  if not callable(true_fn):
    raise TypeError("true_fn must be callable.")
  if not callable(false_fn):
    raise TypeError("false_fn must be callable.")

  if ENABLE_COND_V2 and not context.executing_eagerly():
    # NOTE: `strict` is not forwarded to cond_v2.
    return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name)

  with ops.name_scope(name, "cond", [pred]):
    if context.executing_eagerly():
      # In eager mode simply run the selected branch.
      if pred:
        return _UnpackIfSingleton(true_fn())
      return _UnpackIfSingleton(false_fn())

    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("pred must not be a Python bool")
    p_2, p_1 = switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("true_fn must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("false_fn must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f)
    except TypeError as e:
      raise TypeError(
          "Incompatible return types of true_fn and false_fn: {}".format(e))
    except ValueError as e:
      raise ValueError(
          "Incompatible return values of true_fn and false_fn: {}".format(e))

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError("true_fn and false_fn must return at least one result.")

    res_t_flat = nest.flatten(res_t)
    res_f_flat = nest.flatten(res_f)

    for x, y in zip(res_t_flat, res_f_flat):
      assert ((isinstance(x, ops.IndexedSlices) and
               isinstance(y, ops.IndexedSlices)) or
              (isinstance(x, sparse_tensor.SparseTensor) and
               isinstance(y, sparse_tensor.SparseTensor)) or
              (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)))
      # For composite values, the dtype of `values` is what gets compared.
      val_x = x if isinstance(x, ops.Tensor) else x.values
      val_y = y if isinstance(y, ops.Tensor) else y.values
      if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
        raise ValueError(
            "Outputs of true_fn and false_fn must have the same type: %s, %s" %
            (val_x.dtype.name, val_y.dtype.name))

    # Each merge takes the false-branch output first, then the true-branch one.
    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges
# pylint: enable=g-doc-args
# pylint: enable=redefined-outer-name
def _resource_safe_shape(t):
  """Returns the shape of t or the variable it points to.

  For a tensor of dtype `resource`, the chain of first inputs is followed
  until an op without inputs is reached, and the `shape` attr of that op is
  returned as a `TensorShape`. For any other tensor the result of
  `array_ops.shape_internal(t, optimize=False)` is returned.
  """
  if t.dtype != dtypes.resource:
    return array_ops.shape_internal(t, optimize=False)

  source = t
  while source.op.inputs:
    source = source.op.inputs[0]
  return tensor_shape.TensorShape(source.op.get_attr("shape"))
# TODO(yuanbyu): Consider having a unified notion of context for
# not only conditionals and loops but also control dependency and
# subgraphs.
class WhileContext(ControlFlowContext):
"""The context for the loop construct."""
def __init__(self,
             maximum_iterations=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name="while_context",
             grad_state=None,
             context_def=None,
             import_scope=None):
  """Creates a `WhileContext`.

  When `context_def` is given, the context is restored from that protocol
  buffer and the other loop arguments (except `grad_state` and
  `import_scope`) are ignored; otherwise it is built from the arguments.

  Args:
    maximum_iterations: Optional upper bound on number of loop iterations.
    parallel_iterations: The number of iterations allowed to run in parallel.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.
    grad_state: The gradient loop state.
    context_def: Optional `WhileContextDef` protocol buffer to initialize
      the `WhileContext` python object from.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.
  """
  if not context_def:
    ControlFlowContext.__init__(self)
    self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
                         swap_memory, name)
  else:
    self._init_from_proto(context_def, import_scope=import_scope)
  # The gradient loop state.
  self._grad_state = grad_state
def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
                    swap_memory, name):
  """Initializes a fresh `WhileContext` from constructor arguments.

  Args:
    maximum_iterations: Optional upper bound on number of loop iterations.
    parallel_iterations: The number of iterations allowed to run in parallel.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors; it is made unique
      within the default graph.

  Raises:
    ValueError: If `parallel_iterations` is not a positive `int`.
  """
  is_positive_int = (
      isinstance(parallel_iterations, int) and parallel_iterations > 0)
  if not is_positive_int:
    raise ValueError("`parallel_iterations` must be a positive integer: "
                     "%s" % parallel_iterations)

  graph = ops.get_default_graph()
  self._name = graph.unique_name(name)
  self._maximum_iterations = maximum_iterations
  self._parallel_iterations = parallel_iterations
  self._back_prop = back_prop
  self._swap_memory = swap_memory

  # Pivots start unset.
  # `_pivot_for_pred` controls constants created by the pred lambda.
  self._pivot_for_pred = None
  # `_pivot_for_body` controls constants created by the body lambda.
  self._pivot_for_body = None
  # `_pivot` is the boolean tensor for the loop termination condition; it is
  # used in code generation for gradient computation.
  self._pivot = None

  # Exit and enter tensors for the loop variables.
  self._loop_exits = []
  self._loop_enters = []
  self._graph = graph
def _init_from_proto(self, context_def, import_scope=None):
  """Restores this `WhileContext` from a `WhileContextDef` protocol buffer.

  Every name stored in `context_def` is prefixed with `import_scope` and
  resolved against the default graph.

  Args:
    context_def: `WhileContextDef` protocol buffer.
    import_scope: Optional `string`. Name scope to add.
  """
  assert isinstance(context_def, control_flow_pb2.WhileContextDef)
  graph = ops.get_default_graph()

  def _element(name):
    """Resolves a serialized name to its element in `graph`."""
    return graph.as_graph_element(
        ops.prepend_name_scope(name, import_scope))

  self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
  max_iterations_name = context_def.maximum_iterations_name
  self._maximum_iterations = (
      _element(max_iterations_name) if max_iterations_name else None)
  self._parallel_iterations = context_def.parallel_iterations
  self._back_prop = context_def.back_prop
  self._swap_memory = context_def.swap_memory

  self._pivot_for_pred = _element(context_def.pivot_for_pred_name)
  # We use this node to control constants created by the body lambda.
  self._pivot_for_body = _element(context_def.pivot_for_body_name)
  # The boolean tensor for loop termination condition. Used in code
  # generation for gradient computation.
  self._pivot = _element(context_def.pivot_name)

  # Exit and enter tensors for the loop variables.
  self._loop_exits = [
      _element(exit_name) for exit_name in context_def.loop_exit_names
  ]
  self._loop_enters = [
      _element(enter_name) for enter_name in context_def.loop_enter_names
  ]

  super(WhileContext, self).__init__(
      values_def=context_def.values_def, import_scope=import_scope)

  # import_scope causes self.name to be different from the original serialized
  # context's name. Rewrite "frame_name" attrs with the new name.
  if import_scope:
    for tensor_name in self._values:
      enter_candidate = graph.as_graph_element(tensor_name).op
      if not util.IsLoopEnter(enter_candidate):
        continue
      # pylint: disable=protected-access
      enter_candidate._set_attr(
          "frame_name",
          attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
      # pylint: enable=protected-access
  self._graph = ops.get_default_graph()
@property
def maximum_iterations(self):
  """Upper bound on the number of loop iterations, or None if unset."""
  return self._maximum_iterations
@property
def parallel_iterations(self):
  """How many iterations of this loop may run in parallel."""
  return self._parallel_iterations
@property
def back_prop(self):
  """Whether backprop is enabled for this while loop."""
  return self._back_prop
@property
def swap_memory(self):
  """Whether GPU-CPU memory swap is enabled for this while loop."""
  return self._swap_memory
@property
def pivot(self):
  """The boolean tensor for the loop termination condition (None if unset)."""
  return self._pivot
@property
def loop_enters(self):
  """The list of enter tensors for loop variables (the list itself, not a copy)."""
  return self._loop_enters
@property
def loop_exits(self):
  """The list of exit tensors for loop variables (the list itself, not a copy)."""
  return self._loop_exits
@property
def grad_state(self):
  """The gradient loop state passed to the constructor (may be None)."""
  return self._grad_state
def to_proto(self, export_scope=None):
  """Serializes this `WhileContext` as a `WhileContextDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    A `WhileContextDef` protocol buffer, or None when `export_scope` is given
    and this context's name does not start with it.
  """
  if export_scope is not None and not self.name.startswith(export_scope):
    return None

  def _strip(name):
    """Removes `export_scope` from a graph element name."""
    return ops.strip_name_scope(name, export_scope)

  proto = control_flow_pb2.WhileContextDef()
  proto.context_name = _strip(self.name)
  proto.parallel_iterations = self._parallel_iterations
  if self._maximum_iterations is not None:
    proto.maximum_iterations_name = _strip(self._maximum_iterations.name)
  proto.back_prop = self._back_prop
  proto.swap_memory = self._swap_memory
  proto.pivot_for_pred_name = _strip(self._pivot_for_pred.name)
  proto.pivot_for_body_name = _strip(self._pivot_for_body.name)
  proto.pivot_name = _strip(self._pivot.name)
  proto.loop_exit_names.extend(
      [_strip(exit_tensor.name) for exit_tensor in self._loop_exits])
  proto.loop_enter_names.extend(
      [_strip(enter_tensor.name) for enter_tensor in self._loop_enters])
  proto.values_def.MergeFrom(
      super(WhileContext, self)._to_values_def(export_scope=export_scope))
  # NOTE(review): nested contexts are serialized without `export_scope`, so
  # their names are presumably left unstripped -- confirm this is intended.
  for nested in self._nested_contexts:
    nested.to_control_flow_context_def(proto.nested_contexts.add())
  return proto
def to_control_flow_context_def(self, context_def, export_scope=None):
  """Copies this context's `to_proto()` result into `context_def.while_ctxt`.

  Args:
    context_def: The message whose `while_ctxt` field is filled in.
    export_scope: Optional `string`. Name scope to remove.
  """
  serialized = self.to_proto(export_scope=export_scope)
  context_def.while_ctxt.CopyFrom(serialized)
@staticmethod
def from_proto(context_def, import_scope=None):
  """Returns a `WhileContext` object created from `context_def`.

  The nested contexts listed in `context_def` are imported while the new
  context is entered. The context is exited again even if importing a nested
  context raises.

  Args:
    context_def: A `WhileContextDef` protocol buffer.
    import_scope: Optional `string`. Name scope to add.

  Returns:
    A `WhileContext` Python object.
  """
  ret = WhileContext(context_def=context_def,
                     import_scope=import_scope)

  ret.Enter()
  try:
    for nested_def in context_def.nested_contexts:
      from_control_flow_context_def(nested_def, import_scope=import_scope)
  finally:
    # Always leave the context so a failed import does not keep it entered.
    ret.Exit()
  return ret
def GetWhileContext(self):
  """Returns this context, since it is itself a while context."""
  return self
def GetControlPivot(self):
  """Returns the body pivot if it is set, otherwise the pred pivot."""
  body_pivot = self._pivot_for_body
  return body_pivot if body_pivot is not None else self._pivot_for_pred
def AddValue(self, val):
  """Add `val` to the current context and its outer context recursively.

  Args:
    val: The tensor to make available inside this loop.

  Returns:
    The tensor to use in place of `val` inside this context. For a value seen
    for the first time this is either the result of the gradient state's
    `GetRealValue()` (when building a gradient loop and `val` comes from its
    forward context) or a new constant Enter tensor. For a known value it is
    the replacement recorded in `self._external_values`, or `val` itself.
  """
  result = val
  new_value = val.name not in self._values
  # Don't treat ops in this context as new values. Usually all known values
  # are in self._values, except when we're importing a while loop inside this
  # WhileContext. Since there's a cycle in this case, `val` may be part of the
  # imported while loop but not yet processed by this context and added to
  # self._values in _AddOpInternal. We only want to process external input
  # tensors to the while loop here.
  new_value &= val.op._control_flow_context is not self  # pylint: disable=protected-access
  if new_value:
    self._values.add(val.name)

    # If we are in a grad context and val is from its forward context,
    # use GetRealValue(), which adds the logic to save the history of
    # val in forward.
    grad_ctxt = ops.get_default_graph()._get_control_flow_context()
    if grad_ctxt:
      grad_ctxt = grad_ctxt.GetWhileContext()
      if grad_ctxt.grad_state:
        forward_ctxt = _GetWhileContext(val.op)
        if util.IsLoopExit(val.op):
          # For a loop Exit, compare against the while context enclosing the
          # loop the Exit belongs to.
          forward_ctxt = forward_ctxt.outer_context
          if forward_ctxt:
            forward_ctxt = forward_ctxt.GetWhileContext()
        if forward_ctxt == grad_ctxt.grad_state.forward_context:
          real_val = grad_ctxt.grad_state.GetRealValue(val)
          self._external_values[val.name] = real_val
          return real_val

    # Make the value available in the enclosing context first.
    if self._outer_context is not None:
      result = self._outer_context.AddValue(val)
    # Create an Enter to make `result` known to this loop context.
    with ops.control_dependencies(None):
      enter = _Enter(
          result,
          self._name,
          is_constant=True,
          parallel_iterations=self._parallel_iterations)
      enter.graph.prevent_feeding(enter)
      if self._outer_context:
        self._outer_context.AddInnerOp(enter.op)
    # Fix the control inputs and control flow context of these enter ops.
    self._FixControlInputsAndContext([enter])

    # Add `enter` in this context.
    self._values.add(enter.name)
    self._external_values[val.name] = enter
    result = enter
  else:
    # Known value: use the recorded replacement when there is one.
    actual_val = self._external_values.get(val.name)
    if actual_val is not None:
      result = actual_val
  return result
def AddOp(self, op):
  """Add `op` to the current context.

  `Shape`, `Size` and `Rank` ops built in a gradient loop whose first input
  comes from the corresponding forward loop are added to the control flow
  context of that input instead of this one. Every other op is handed to
  `_AddOpInternal` of this context.
  """
  # For a reduction op, if op is in a grad context and its input is from
  # its forward context, moving op to the forward context means we would
  # store the tensor after the reduction as opposed to the tensor before
  # reduction, and therefore could significantly reduce memory consumption.
  # For now, we do this only for a few ops.
  if op.type in ("Shape", "Size", "Rank"):
    active_ctxt = ops.get_default_graph()._get_control_flow_context()
    if active_ctxt:
      while_ctxt = active_ctxt.GetWhileContext()
      if while_ctxt.grad_state:
        source_op = op.inputs[0].op
        if (_GetWhileContext(source_op) ==
            while_ctxt.grad_state.forward_context):
          source_ctxt = source_op._get_control_flow_context()
          op._set_control_flow_context(source_ctxt)
          source_ctxt._AddOpInternal(op)
          return
  self._AddOpInternal(op)
def _AddOpInternal(self, op):
  """Add `op` to the current context.

  We move any external control dependencies of the op to the loop pivot, to
  ensure they get executed.

  Args:
    op: The operation to add. Its data inputs are replaced by the values
      returned from `AddValue`, and its outputs are recorded in
      `self._values`.
  """
  if not op.inputs:
    # Remove any external control dependency on this op
    control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
    # Add a control edge from the control pivot to this op.
    if not control_inputs:
      # pylint: disable=protected-access
      op._add_control_input(self.GetControlPivot().op)
      # pylint: enable=protected-access
    for x in op.outputs:
      self._values.add(x.name)
  else:
    # Rewire each data input to the tensor AddValue returns for it.
    for index in range(len(op.inputs)):
      x = op.inputs[index]
      real_x = self.AddValue(x)
      if real_x != x:
        op._update_input(index, real_x)  # pylint: disable=protected-access
    # Remove any external control dependency on this op.
    _, external_inputs = self._RemoveExternalControlEdges(op)
    # Add a control dependency to prevent loop invariants from
    # enabling ops that should not be executed.
    self._MaybeAddControlDependency(op)
    for x in op.outputs:
      self._values.add(x.name)
  if external_inputs:
    # Use an identity to pull control inputs as data inputs. Note that we
    # ignore ops which don't have outputs. TODO(apassos): fix that
    with ops.control_dependencies(None):
      self.Enter()
      external_inputs = [array_ops.identity(x.outputs[0]).op
                         for x in external_inputs if x.outputs]
      self.Exit()
    op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
  # Only a loop Exit of an outermost context stays fetchable and feedable.
  if self._outer_context or not util.IsLoopExit(op):
    op.graph.prevent_fetching(op)
    for x in op.outputs:
      op.graph.prevent_feeding(x)

  # Let the enclosing context know about the op as well.
  if self._outer_context:
    self._outer_context.AddInnerOp(op)
def _MaybeAddControlDependency(self, op):
"""Add a control input to the op if it only depends on loop invariants."""
def _IsOpFree(op):
"""Determines if `op` needs a control dependency."""
if op.control_inputs:
return False
# pylint: disable=protected-access
if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
return True
# pylint: enable=protected-access
for x in op.inputs:
if not util.IsLoopConstantEnter(x.op):
return False
return True
if _IsOpFree(op):
# pylint: disable=protected-access
op._add_control_input(self.GetControlPivot().op)
# pylint: enable=protected-access
def AddForwardLoopCounter(self, outer_grad_state):
  """Adds a loop that counts the number of iterations.

  This is added to the forward loop at the time when we start to
  create the loop for backprop gradient computation. Called in
  the outer context of this forward context.

  The pseudocode is:
    `n = 0; while (_pivot) { n++; }`

  Note that a control dependency is added to `n` to ensure the correct
  execution order of stack push ops.

  Args:
    outer_grad_state: The outer grad state. None if not nested.

  Returns:
    The number of iterations taken by the forward loop and the loop index.
  """
  n = constant_op.constant(0, name="f_count")
  if outer_grad_state is not None:
    # Force the stack pushes of i-th execution of an inner loop to be ordered
    # before the pushes of (i+1)-th execution of the same inner loop.
    outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
    n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access

  self.Enter()
  self.AddName(n.name)
  enter_n = _Enter(
      n,
      self._name,
      is_constant=False,
      parallel_iterations=self._parallel_iterations,
      name="f_count")
  self.loop_enters.append(enter_n)

  # The merge is first built with the Enter on both inputs; input 1 is
  # replaced by the NextIteration below to close the loop.
  merge_n = merge([enter_n, enter_n])[0]
  switch_n = switch(merge_n, self._pivot)

  # switch_n[1] feeds the increment; switch_n[0] feeds the loop exit.
  index = math_ops.add(switch_n[1], 1)
  next_n = _NextIteration(index)
  merge_n.op._update_input(1, next_n)

  total_iterations = exit(switch_n[0], name="f_count")
  self.loop_exits.append(total_iterations)
  self.ExitResult([total_iterations])
  self.Exit()
  return total_iterations, next_n
def AddBackpropLoopCounter(self, count, outer_grad_state):
  """Add the backprop loop that controls the iterations.

  This is added to the backprop loop. It is used to control the loop
  termination of the backprop loop. Called in the outer context of
  this grad context.

  The pseudocode is:
    `n = count; while (n >= 1) { n--; }`

  Note that a control dependency is added to `final_zero` to ensure the
  correct execution order of stack pop ops.

  As a side effect this sets `self._pivot_for_pred`, `self._pivot` and
  `self._pivot_for_body`.

  Args:
    count: The number of iterations for backprop.
    outer_grad_state: The outer grad state. None if not nested.

  Returns:
    The loop index.
  """
  in_separate_functions = count.graph is not ops.get_default_graph()
  if in_separate_functions:
    # Brings the count into this graph
    count = array_ops.identity(count)
  else:
    # TODO(apassos) XLA expects this constant to be created outside the loop,
    # so doing that for now.
    one = constant_op.constant(1, name="b_count")

  self.Enter()
  self.AddName(count.name)
  enter_count = _Enter(
      count,
      self._name,
      is_constant=False,
      parallel_iterations=self._parallel_iterations,
      name="b_count")
  self.loop_enters.append(enter_count)

  # The merge is first built with the Enter on both inputs; input 1 is
  # replaced by the NextIteration below to close the loop.
  merge_count = merge([enter_count, enter_count])[0]
  self._pivot_for_pred = merge_count

  # When `count` lives in another graph, the constant is created here,
  # after entering the loop context, instead of before it.
  if in_separate_functions:
    one = constant_op.constant(1, name="b_count")
  pred = math_ops.greater_equal(merge_count, one)
  self._pivot = loop_cond(pred, name="b_count")
  switch_count = switch(merge_count, self._pivot)

  # switch_count[1] feeds the decrement; switch_count[0] feeds the exit.
  index = math_ops.subtract(switch_count[1], one)
  self._pivot_for_body = index
  next_count = _NextIteration(index)
  merge_count.op._update_input(1, next_count)

  final_zero = exit(switch_count[0], name="b_count")
  self.loop_exits.append(final_zero)
  if outer_grad_state is not None:
    # Force the stack pops of i-th execution of an inner loop to be ordered
    # before the pops of (i+1)-th execution of the same inner loop.
    # pylint: disable=protected-access
    outer_grad_state.grad_sync._add_control_input(final_zero.op)
    # pylint: enable=protected-access

  self.ExitResult([final_zero])
  self.Exit()
  return next_count
def AddBackpropAccumulator(self, op, grad):
  """Add an accumulation loop for every loop invariant.

  This is added to the backprop loop. It is used to accumulate partial
  gradients within each loop iteration. Called when in the gradient while
  context.

  The pseudocode is:
    ```
    acc = 0.0;
    while (_pivot) {
      acc += grad;
    }
    ```

  This method exits the context at the start and re-enters it before
  building the loop; it does not exit again before returning.

  Args:
    op: The Enter op for a loop invariant.
    grad: The partial gradient of an iteration for a loop invariant.

  Returns:
    The gradient for a loop invariant.
  """
  self.Exit()
  # Create a zeros tensor with the right shape for acc. If we don't
  # know the full shape statically, we will have to get the shape
  # dynamically from the forward inference. Getting the shape right
  # for the zeros is only needed for the base case when the loop exits
  # without running any iterations.
  shape = grad.get_shape()
  if shape.is_fully_defined():
    # Static shape known: build the zero constant in the outer context.
    if self.outer_context:
      self.outer_context.Enter()
    acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
    if self.outer_context:
      self.outer_context.Exit()
  else:
    value = op.inputs[0]
    if (isinstance(self.outer_context, WhileContext) and
        self.outer_context.grad_state is not None):
      # We are in a nested while loop.
      # The shape is computed in the forward loop's outer context and
      # carried to the backprop side through the outer grad state.
      forward_ctxt = self.grad_state.forward_context
      forward_ctxt.outer_context.Enter()
      zeros_shape = array_ops.shape_internal(value, optimize=False)
      forward_ctxt.outer_context.Exit()
      outer_grad_state = self.grad_state.outer_grad_state
      history_zeros_shape = outer_grad_state.AddForwardAccumulator(
          zeros_shape)
      self.outer_context.Enter()
      real_shape = outer_grad_state.AddBackpropAccumulatedValue(
          history_zeros_shape, zeros_shape)
      acc = array_ops.zeros(real_shape, grad.dtype)
      self.outer_context.Exit()
    else:
      # Not nested: take the dynamic shape of the loop-invariant input.
      if self.outer_context:
        self.outer_context.Enter()
      zeros_shape = array_ops.shape_internal(value, optimize=False)
      acc = array_ops.zeros(zeros_shape, grad.dtype)
      if self.outer_context:
        self.outer_context.Exit()

  self.Enter()
  self.AddName(acc.name)
  enter_acc = _Enter(
      acc,
      self._name,
      is_constant=False,
      parallel_iterations=self._parallel_iterations,
      name="b_acc")
  self.loop_enters.append(enter_acc)

  # The merge is first built with the Enter on both inputs; input 1 is
  # replaced by the NextIteration below to close the loop.
  merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
  switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)

  add_acc = math_ops.add(switch_acc_true, grad)
  next_acc = _NextIteration(add_acc)
  merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access

  result_acc = exit(switch_acc_false, name="b_acc")
  self.loop_exits.append(result_acc)
  self.ExitResult([result_acc])
  return result_acc
def AddBackpropIndexedSlicesAccumulator(self, op, grad):
  """This is used for accumulating gradients that are IndexedSlices.

  This is essentially the equivalent of AddBackpropAccumulator but optimized
  for things like updating embeddings from within a while loop.

  Indices and values are accumulated by concatenation along axis 0, starting
  from a single initial entry (index 0 with a zero-valued row). The dense
  shape, when present, is accumulated as an element-wise maximum.

  This method exits the context at the start and re-enters it before
  building the loop; it does not exit again before returning.

  Args:
    op: The Enter op for a loop invariant.
    grad: The partial gradients represented as an IndexedSlices.

  Returns:
    The accumulated IndexedSlices gradient of the loop invariant.
  """
  values = grad.values
  indices = grad.indices
  dense_shape = grad.dense_shape

  self.Exit()
  # The initial accumulator values are built in the outer context.
  if self.outer_context:
    self.outer_context.Enter()
  if values.get_shape().is_fully_defined():
    # Static shape known: one zero row shaped like a row of `values`.
    values_shape = tensor_shape.TensorShape(
        [tensor_shape.Dimension(1)] + values.get_shape().dims[1:])
    # NOTE(review): the outer context was already entered above, so this
    # inner Enter/Exit pair looks redundant -- confirm before removing.
    if self.outer_context:
      self.outer_context.Enter()
    values_acc = constant_op.constant(
        0, values.dtype, shape=values_shape, name="b_acc")
    if self.outer_context:
      self.outer_context.Exit()
  else:
    # Otherwise derive the row shape from the loop-invariant input.
    values_shape = _resource_safe_shape(op.inputs[0])[1:]
    values_shape = array_ops.concat([[1], values_shape], 0)
    values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
  indices_acc = constant_op.constant([0], indices.dtype)
  shape_acc = None
  if dense_shape is not None:
    if dense_shape.get_shape().is_fully_defined():
      if self.outer_context:
        self.outer_context.Enter()
      shape_acc = constant_op.constant(
          0, dense_shape.dtype, shape=dense_shape.get_shape())
      if self.outer_context:
        self.outer_context.Exit()
    else:
      shape_acc = array_ops.zeros_like(
          array_ops.shape_internal(op.inputs[0], optimize=False,
                                   out_type=dense_shape.dtype),
          optimize=False)

  if self.outer_context:
    self.outer_context.Exit()

  self.Enter()
  self.AddName(values_acc.name)
  self.AddName(indices_acc.name)
  # Accumulator order is fixed: indices, values, then (optionally) shape.
  init_acc = [indices_acc, values_acc]
  if shape_acc is not None:
    self.AddName(shape_acc.name)
    init_acc.append(shape_acc)

  # Set use_input_shape=False since the accumulator tensors will grow in
  # size. If use_input_shape=True, the _update_input call below will result in
  # incompatible shapes.
  enter_acc = [
      _Enter(
          x,
          self._name,
          is_constant=False,
          parallel_iterations=self._parallel_iterations,
          use_input_shape=False,
          name="b_acc") for x in init_acc
  ]
  # Manually set appropriate partial shapes.
  enter_acc[0].set_shape([None])
  if values_acc.shape.dims is not None:
    enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
  self.loop_enters.extend(enter_acc)

  # Each merge is first built with its Enter on both inputs; input 1 is
  # replaced by the matching NextIteration below to close the loop.
  merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
  switch_acc = [switch(x, self._pivot) for x in merge_acc]

  # The actual accumulation.
  acc_indexed_slices = [
      array_ops.concat([xa[1], xv], 0)
      for xa, xv in zip(switch_acc[:2], [indices, values])
  ]
  if shape_acc is not None:
    # For the shape we just keep the maximum
    acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))

  next_acc = [_NextIteration(x) for x in acc_indexed_slices]
  for xm, xn in zip(merge_acc, next_acc):
    xm.op._update_input(1, xn)  # pylint: disable=protected-access

  exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
  self.loop_exits.extend(exit_acc)

  self.ExitResult(exit_acc)
  return ops.IndexedSlices(
      indices=exit_acc[0],
      values=exit_acc[1],
      dense_shape=exit_acc[2] if shape_acc is not None else None)
def _InitializeValues(self, values):
"""Makes the values known to this context."""
self._values = set()
for x in values:
if isinstance(x, ops.Tensor):
self._values.add(x.name)
else:
self._values.add(x.values.name)
self._values.add(x.indices.name)
if isinstance(x, ops.IndexedSlices):
dense_shape = x.dense_shape
elif isinstance(x, sparse_tensor.SparseTensor):
dense_shape = x.dense_shape
else:
raise TypeError("Type %s not supported" % type(x))
if dense_shape is not None:
self._values.add(dense_shape.name)
def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
               shape_invariants):
  """Core: Add the loop termination condition and body to the graph.

  Creates, in order, the Enter, Merge, LoopCond, Switch, body, NextIteration
  and Exit pieces of the loop for every flattened loop variable.

  Args:
    pred: Callable invoked once with the packed Merge outputs; its result is
      converted to a tensor and fed to `loop_cond`.
    body: Callable invoked once with the packed per-iteration values.
    original_loop_vars: The loop variables in their original (possibly
      nested) structure; used as the packing structure and to identify
      TensorArrays.
    loop_vars: The flattened loop variables as tensors (TensorArrays already
      replaced by their flows by the caller).
    shape_invariants: Shape invariants forwarded to `_SetShapeInvariants`, or
      None, in which case the Enter ops use their input shapes.

  Returns:
    A tuple `(original_body_result, exit_vars)`: the (possibly nested) value
    returned by `body` (wrapped in a list if it was not a sequence), and the
    list of Exit outputs, one per flattened loop variable.

  Raises:
    ValueError: If the number of flattened body outputs differs from the
      number of loop variables.
  """
  flat_loop_vars = nest.flatten(original_loop_vars)

  # Let the context know the loop variables so the loop variables
  # would be added in the outer contexts properly.
  self._InitializeValues(loop_vars)
  real_vars = loop_vars
  if self._outer_context:
    real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
  # Create the Enter ops with any pending control dependencies cleared.
  with ops.control_dependencies(None):
    enter_vars = [
        _Enter(
            x,
            self._name,
            is_constant=False,
            parallel_iterations=self._parallel_iterations,
            use_input_shape=(shape_invariants is None)) for x in real_vars
    ]
    for x in enter_vars:
      x.graph.prevent_feeding(x)
      if self._outer_context:
        self._outer_context.AddInnerOp(x.op)

  # Finds the closest enclosing non-None control pivot.
  outer_context = self._outer_context
  control_pivot = None
  while outer_context is not None and control_pivot is None:
    control_pivot = outer_context.GetControlPivot()
    # pylint: disable=protected-access
    outer_context = outer_context._outer_context
    # pylint: enable=protected-access

  # Enter ops fed by a loop-constant Enter get a control input from the
  # enclosing pivot found above.
  if control_pivot is not None:
    for var in enter_vars:
      if util.IsLoopConstantEnter(var.op.inputs[0].op):
        # pylint: disable=protected-access
        var.op._add_control_input(control_pivot.op)
        # pylint: enable=protected-access
  _SetShapeInvariants(real_vars, enter_vars, shape_invariants)

  # Fix the control inputs and control flow context of these enter ops.
  self._FixControlInputsAndContext(enter_vars)
  self._InitializeValues(enter_vars)
  self._loop_enters = enter_vars

  # Each Merge initially receives its Enter twice; the second input is
  # replaced by the back edge once the body has been built.
  merge_vars = [merge([x, x])[0] for x in enter_vars]
  self._pivot_for_pred = merge_vars[0]

  # Build the graph for pred.
  merge_vars_with_tensor_arrays = (
      _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
  packed_vars = nest.pack_sequence_as(
      structure=original_loop_vars,
      flat_sequence=merge_vars_with_tensor_arrays)
  c = ops.convert_to_tensor(pred(*packed_vars))
  self._pivot = loop_cond(c, name="LoopCond")
  switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]

  # Build the graph for body.
  # Output 1 of each Switch feeds the body; output 0 feeds the Exit below.
  vars_for_body = [_Identity(x[1]) for x in switch_vars]
  self._pivot_for_body = vars_for_body[0]
  # Convert TensorArray flow variables inside the context back into
  # their associated TensorArrays for calling the body.
  vars_for_body_with_tensor_arrays = (
      _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
  packed_vars_for_body = nest.pack_sequence_as(
      structure=original_loop_vars,
      flat_sequence=vars_for_body_with_tensor_arrays)
  # Snapshot the summary collection so summaries added by `body` can be
  # detected afterwards.
  pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
  body_result = body(*packed_vars_for_body)
  post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
  if not nest.is_sequence(body_result):
    body_result = [body_result]
  if len(post_summaries) > len(pre_summaries):
    # Summaries created inside the body are removed from the collection and
    # instead attached as control dependencies of the body outputs.
    new_summaries = post_summaries[len(pre_summaries):]
    summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    summary_ref[:] = pre_summaries
    with ops.control_dependencies(new_summaries):
      def map_fn(x):
        # TODO(apassos) figure out how to trigger with tensor arrays as well
        if isinstance(x, tensor_array_ops.TensorArray):
          return x
        return array_ops.identity(x)
      body_result = nest.map_structure(map_fn, body_result)

  # Compare the structure types of input and output of body.
  # For backwards compatibility, the first layer is forced to a list
  # during this comparison, because inputs are typically lists and
  # outputs of the body are typically tuples.
  nest.assert_same_structure(list(packed_vars_for_body), list(body_result))

  # Store body_result to keep track of TensorArrays returned by body
  original_body_result = body_result
  # Convert TensorArrays returned by body into their flow variables
  result = nest.map_structure(_convert_tensorarray_to_flow,
                              nest.flatten(body_result))
  result = ops.convert_n_to_tensor_or_indexed_slices(result)

  # Add NextIteration and the back edges to complete the loop.
  if len(merge_vars) != len(result):
    raise ValueError("Number of inputs and outputs of body must match "
                     "loop_vars: %d, %d" % (len(merge_vars), len(result)))
  # NOTE: `next_vars` is collected but not read again in this method.
  next_vars = []
  for m, v in zip(merge_vars, result):
    next_vars.append(_AddNextAndBackEdge(m, v))

  # Add the exit ops.
  exit_vars = [exit(x[0]) for x in switch_vars]
  self._loop_exits = exit_vars

  # Exit the loop.
  self.ExitResult(exit_vars)

  return original_body_result, exit_vars
def BuildLoop(self, pred, body, loop_vars, shape_invariants,
              return_same_structure):
  """Builds the complete while loop inside this context.

  Flattens `loop_vars` (replacing TensorArrays by their flow tensors), builds
  the loop via `_BuildLoop` while this context is entered, and then packs the
  loop exits back into the structure returned by `body`.

  Args:
    pred: Callable computing the loop condition.
    body: Callable computing one loop iteration.
    loop_vars: The (possibly nested) loop variables.
    shape_invariants: Shape invariants for the loop variables, or None.
    return_same_structure: If True, always return the packed structure. If
      False, a result with exactly one exit tensor is unwrapped.

  Returns:
    The loop exit values packed like the body result, or its first element
    when `return_same_structure` is False and there is a single exit tensor.
  """
  # The untouched structure is needed later to tell which entries were
  # TensorArrays.
  structured_vars = loop_vars
  flat_vars = ops.convert_n_to_tensor_or_indexed_slices(
      nest.map_structure(_convert_tensorarray_to_flow,
                         nest.flatten(loop_vars)))

  try:
    self.Enter()
    # _BuildLoop mutates ops it has just created (_update_input). Holding the
    # graph's mutation lock prevents a Session.run call from happening
    # between creation and mutation.
    graph = ops.get_default_graph()
    with graph._mutation_lock():  # pylint: disable=protected-access
      body_outputs, loop_exits = self._BuildLoop(
          pred, body, structured_vars, flat_vars, shape_invariants)
  finally:
    self.Exit()

  # Outside the loop context, turn flow tensors back into the TensorArrays
  # returned by the body before handing the results to the caller.
  restored_exits = _convert_flows_to_tensorarrays(
      nest.flatten(body_outputs), loop_exits)
  packed_exits = nest.pack_sequence_as(
      structure=body_outputs, flat_sequence=restored_exits)

  if not return_same_structure and len(loop_exits) == 1:
    return packed_exits[0]
  return packed_exits
def _FixControlInputsAndContext(self, enters):
  """Moves the ops of `enters` into this context and fixes control inputs.

  For every component tensor of every entry, the producing op is assigned
  this context as its control flow context, receives as control inputs those
  control dependencies of its first input's op that `_IsInOuterContext`
  accepts, and is recorded with the graph as seen by control dependencies.

  Args:
    enters: Iterable of `Tensor`, `IndexedSlices` or `SparseTensor` objects.

  Raises:
    TypeError: If an entry is none of the supported types.
  """
  default_graph = ops.get_default_graph()
  # pylint: disable=protected-access
  for enter in enters:
    if isinstance(enter, ops.Tensor):
      components = [enter]
    elif isinstance(enter, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
      components = [enter.values, enter.indices]
      dense_shape = enter.dense_shape
      if dense_shape is not None:
        components.append(dense_shape)
    else:
      raise TypeError("Type %s not supported" % type(enter))

    for component in components:
      source_op = component.op.inputs[0].op
      source_deps = default_graph._control_dependencies_for_inputs(
          [source_op])
      outer_deps = [dep for dep in source_deps if self._IsInOuterContext(dep)]
      component.op._set_control_flow_context(self)
      component.op._add_control_inputs(outer_deps)
      default_graph._record_op_seen_by_control_dependencies(component.op)
  # pylint: enable=protected-access
def IsWhileContext(self):
  """Returns True: this context is a while-loop context."""
  return True
# pylint: disable=redefined-outer-name
@tf_export("while_loop")
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None,
               return_same_structure=False):
  """Repeat `body` while the condition `cond` is true.

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
  `cond` and `body`. `cond` and `body` both take as many arguments as there are
  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and
  return TensorArray objects.  The flows of the TensorArray objects will
  be appropriately forwarded between loops and during gradient calculations.

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
  stitches together the graph fragments created during the `cond` and `body`
  calls with some additional graph nodes to create the graph flow that
  repeats `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
  compatible with [11, 17]. By default (if the argument `shape_invariants` is
  not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape`
  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariant for
  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  TensorShape([r]) where r is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
  [shape.ndims]).

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any parallel_iterations > 0.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This for example allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run.  If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    return_same_structure: If True, output has same structure as `loop_vars`. If
      eager execution is enabled, this is ignored (and always treated as True).

  Returns:
    The output tensors for the loop variables after the loop.
     If `return_same_structure` is True, the return value has the same
     structure as `loop_vars`.
     If `return_same_structure` is False, the return value is a Tensor,
     TensorArray or IndexedSlices if the length of `loop_vars` is 1, or a list
     otherwise.

  Raises:
    TypeError: if `cond` or `body` is not callable, or if
      `parallel_iterations` is less than 1.
    ValueError: if `loop_vars` is empty, or if `maximum_iterations` is
      provided and is not a scalar.

  Example:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])
  ```

  Example with nesting and a namedtuple:

  ```python
  import collections
  Pair = collections.namedtuple('Pair', 'j, k')
  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  c = lambda i, p: i < 10
  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  ijk_final = tf.while_loop(c, b, ijk_0)
  ```

  Example using shape_invariants:

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  ```

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter parallel to updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we print on the line `print(sess.run(i))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is that the thread updating `x` can never get ahead of the
  counter thread because the thread incrementing `x` depends on the value
  of the counter.

  ```python
  import tensorflow as tf

  n = 10000
  x = tf.constant(list(range(n)))
  c = lambda i, x: i < n
  b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
  i, out = tf.while_loop(c, b, (0, x))
  with tf.Session() as sess:
      print(sess.run(i))  # prints [0] ... [9999]

      # The following line may increment the counter and x in parallel.
      # The counter thread may get ahead of the other thread, but not the
      # other way around. So you may see things like
      # [9996] x:[9987]
      # meaning that the counter thread is on iteration 9996,
      # while the other thread is on iteration 9987
      print(sess.run(out).shape)
  ```

  """
  with ops.name_scope(name, "while", loop_vars):
    if not loop_vars:
      raise ValueError("No loop variables provided")
    if not callable(cond):
      raise TypeError("cond must be callable.")
    if not callable(body):
      raise TypeError("body must be callable.")
    if parallel_iterations < 1:
      raise TypeError("parallel_iterations must be a positive integer.")

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, name="maximum_iterations")
      if maximum_iterations.shape.ndims != 0:
        raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
                         maximum_iterations.shape)

      # Prepend an iteration counter to the loop variables and wrap `cond`
      # and `body` so the loop also stops once the counter reaches
      # `maximum_iterations`. The user's variables become element 1.
      counter = constant_op.constant(
          0, dtype=maximum_iterations.dtype, name="iteration_counter")
      orig_cond = cond
      orig_body = body
      if len(loop_vars) == 1:
        # A single loop variable is passed to the user callables directly
        # rather than unpacked.
        loop_vars = (counter, loop_vars[0])
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
        body = lambda i, lv: (i + 1, orig_body(lv))
      else:
        loop_vars = (counter, loop_vars)
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
        body = lambda i, lv: (i + 1, orig_body(*lv))

    if context.executing_eagerly():
      # In eager mode the loop is executed directly as a Python `while`.
      try_to_pack = len(loop_vars) == 1
      packed = False  # whether the body result was packed into a 1-item tuple

      while cond(*loop_vars):
        loop_vars = body(*loop_vars)
        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
          packed = True
          loop_vars = (loop_vars,)
      if maximum_iterations is not None:
        # Drop the iteration counter added above.
        return loop_vars[1]
      else:
        return loop_vars[0] if packed else loop_vars

    if shape_invariants is not None:
      if maximum_iterations is not None:
        # The added iteration counter is a scalar.
        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
      nest.assert_same_structure(loop_vars, shape_invariants)

    loop_context = WhileContext(
        maximum_iterations=maximum_iterations,
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)
    # Only add non-nested loops to the collection. Any nested control flow will
    # be encapsulated in the root context.
    if loop_context.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
                                    return_same_structure)
    if maximum_iterations is not None:
      # Element 0 is the iteration counter; return only the user's variables.
      return result[1]
    else:
      return result
# pylint: enable=redefined-outer-name
def _AsTensorList(x, p):
  """Converts `x` into a list of identity Tensors / IndexedSlices.

  A single value is treated as a one-element list. Each `Operation` entry is
  replaced by `p` gated on that operation (via `with_dependencies`). Every
  entry is then converted to a Tensor or IndexedSlices and passed through
  `identity`; for IndexedSlices only `values` and `indices` are carried over.

  Args:
    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
    p: A Tensor to return for entries in `x` that are Operations.

  Returns:
    A list of Tensors or IndexedSlices.
  """
  entries = x if isinstance(x, (list, _basetuple)) else [x]

  tensors = []
  for entry in entries:
    if isinstance(entry, ops.Operation):
      entry = with_dependencies([entry], p)
    converted = ops.convert_to_tensor_or_indexed_slices(entry)
    if not isinstance(converted, ops.Tensor):
      values_copy = array_ops.identity(converted.values)
      indices_copy = array_ops.identity(converted.indices)
      tensors.append(ops.IndexedSlices(values_copy, indices_copy))
      continue
    tensors.append(array_ops.identity(converted))
  return tensors
def _CheckResults(a, b):
assert len(a) == len(b), (
"Values returned by a() and b() must have the same length.")
for x, y in zip(a, b):
assert x.dtype == y.dtype, (
"Values returned by a() [%s] and b() [%s] must have "
"the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
def with_dependencies(dependencies, output_tensor, name=None):
  """Produces the content of `output_tensor` only after `dependencies`.

  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Only the returned value is gated:
  there is no guarantee that `output_tensor` itself is evaluated after the
  `dependencies` have run.

  When executing eagerly, `output_tensor` is returned unchanged.

  See also `tf.tuple` and `tf.group`.

  Args:
    dependencies: Iterable of operations to run before this op finishes.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
  """
  if context.executing_eagerly():
    return output_tensor

  scope_values = list(dependencies) + [output_tensor]
  with ops.name_scope(name, "control_dependency", scope_values) as scope:
    with ops.colocate_with(output_tensor), ops.control_dependencies(
        dependencies):
      gated = ops.convert_to_tensor_or_indexed_slices(output_tensor)
      if not isinstance(gated, ops.Tensor):
        # For IndexedSlices only the values go through the gated identity.
        return ops.IndexedSlices(
            _Identity(gated.values, name=scope), gated.indices,
            gated.dense_shape)
      return _Identity(gated, name=scope)
def _GroupControlDeps(dev, deps, name=None):
  """Creates a NoOp that depends on `deps`, placed on `dev` if given.

  Args:
    dev: Device for the NoOp, or None to not set a device here.
    deps: Control dependencies of the NoOp.
    name: Optional name for the NoOp.

  Returns:
    The created NoOp operation.
  """
  with ops.control_dependencies(deps):
    if dev is not None:
      with ops.device(dev):
        return no_op(name=name)
    return no_op(name=name)
# TODO(touts): Accept "inputs" as a list.
@tf_export("group")
def group(*inputs, **kwargs):
  """Create an op that groups multiple operations.

  The returned op has no output; once it finishes, every op in `inputs` has
  finished. Inputs are bucketed by device: with a single device the result is
  one NoOp depending on all inputs, otherwise one NoOp per device is created
  and the returned NoOp depends on those.

  When executing eagerly, this returns None.

  See also `tf.tuple` and
  `tf.control_dependencies`.

  Args:
    *inputs: Zero or more tensors to group.
    name: A name for this operation (optional).

  Returns:
    An Operation that executes all its inputs.

  Raises:
    ValueError: If an unknown keyword argument is provided.
    TypeError: If an input has no `device` attribute.
  """
  if context.executing_eagerly():
    return None
  name = kwargs.pop("name", None)
  if kwargs:
    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs))

  with ops.name_scope(name, "group_deps", inputs) as name:
    # Nothing to group: a bare NoOp.
    if not inputs:
      return no_op(name=name)

    # Bucket the flattened inputs by the device they are assigned to.
    by_device = {}
    for item in nest.flatten(inputs):
      if not hasattr(item, "device"):
        raise TypeError("Expected tf.group() expected Tensor arguments not "
                        "'%s' with type '%s'" % (item, type(item)))
      by_device.setdefault(item.device, []).append(item)

    if len(by_device) == 1:
      # Single device: the returned NoOp depends on the inputs directly.
      only_device, only_deps = list(by_device.items())[0]
      return _GroupControlDeps(only_device, only_deps, name=name)

    # Several devices: one intermediate NoOp per device, visited in sorted
    # device order. None sorts as "" so it can be compared with strings.
    ordered_devices = sorted(
        by_device, key=lambda device: "" if device is None else device)
    per_device_noops = [
        _GroupControlDeps(device, by_device[device])
        for device in ordered_devices
    ]
    with ops.control_dependencies(per_device_noops):
      return no_op(name=name)
@tf_export("tuple")
def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
  """Group tensors together.

  Returns a list with the same values as `tensors`, where each value only
  becomes available once all entries of `tensors` have been computed.
  `control_inputs` lists further ops that must finish first, but whose
  outputs are not returned.

  This works as a "join" for parallel computations: the inputs may be
  computed in parallel, but nothing returned by `tuple` is available before
  all of them are done.

  When executing eagerly, `tensors` is returned unchanged.

  See also `tf.group` and
  `tf.control_dependencies`.

  Args:
    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
    name: (optional) A name to use as a `name_scope` for the operation.
    control_inputs: List of additional ops to finish before returning.

  Returns:
    Same as `tensors`.

  Raises:
    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
      objects.
  """
  if context.executing_eagerly():
    return tensors

  with ops.name_scope(name, "tuple", tensors):
    # Operations, tensors and None pass through; anything else is converted.
    normalized = []
    for item in tensors:
      keep_as_is = (isinstance(item, ops.Operation) or
                    tensor_util.is_tensor(item) or item is None)
      normalized.append(item if keep_as_is else ops.convert_to_tensor(item))

    # Collect the ops that every returned value has to wait for.
    gating_ops = []
    for item in normalized:
      if item is None:
        continue
      gating_ops.append(item if isinstance(item, ops.Operation) else item.op)
    for dep in control_inputs or ():
      if isinstance(dep, ops.Tensor):
        dep = dep.op
      elif not isinstance(dep, ops.Operation):
        raise TypeError("Control input must be Operation or Tensor: %s" % dep)
      gating_ops.append(dep)

    # De-duplicate, then sort by op id so the order in the pbtxt is stable.
    unique_ops = sorted(set(gating_ops), key=lambda op: op._id)
    if not unique_ops:
      raise ValueError("Must have at least one Tensor: %s" % normalized)
    gate = group(*unique_ops)

    gated = []
    for item in normalized:
      if tensor_util.is_tensor(item):
        gated.append(with_dependencies([gate], item))
      elif isinstance(item, ops.Operation):
        with ops.control_dependencies([gate]):
          gated.append(group(item))
      else:
        gated.append(None)
    return gated
def _assert_at_most_n_true(predicates, n, msg):
  """Returns an Assert op that checks that at most n predicates are True.

  Args:
    predicates: list of bool scalar tensors.
    n: maximum number of true predicates allowed.
    msg: Error message.

  Returns:
    The result of `Assert` on "number of true predicates <= n", whose data
    is a message naming the predicates followed by the stacked predicates.
  """
  stacked = array_ops.stack(predicates, name="preds_c")
  as_ints = math_ops.cast(stacked, dtypes.int32)
  true_count = math_ops.reduce_sum(as_ints, name="num_true_conds")
  limit = constant_op.constant(n, name="n_true_conds")
  within_limit = math_ops.less_equal(true_count, limit)

  # Predicates without a `name` attribute are shown as "?".
  names = ", ".join(getattr(pred, "name", "?") for pred in predicates)
  header = "%s: more than %d conditions (%s) evaluated as True:" % (
      msg, n, names)
  return Assert(
      within_limit, data=[header, stacked], summarize=len(predicates))
def _case_create_default_action(predicates, actions):
  """Creates default action for a list of actions and their predicates.

  The last predicate/action pair is turned into the default. When called,
  the default asserts that none of the other predicates is true and that its
  own predicate is true, and then runs its action.

  Args:
    predicates: a list of bool scalar tensors
    actions: a list of callable objects which return tensors.

  Returns:
    A tuple `(default_action, other_predicates, other_actions)`: the default
    callable and the predicates and actions with the last pair removed.
  """
  last = len(predicates) - 1  # could pick any
  chosen_predicate = predicates[last]
  chosen_action = actions[last]
  remaining_predicates = predicates[:last]
  remaining_actions = actions[:last]

  def default_action():
    others_msg = ("Implementation error: "
                  "selected default action #%d was called, but some of other "
                  "predicates are True: " % last)
    default_msg = ("Input error: "
                   "None of conditions evaluated as True:",
                   array_ops.stack(predicates, name="preds_c"))
    none_of_the_others = _assert_at_most_n_true(
        remaining_predicates, n=0, msg=others_msg)
    chosen_is_true = Assert(chosen_predicate, data=default_msg)
    with ops.control_dependencies([none_of_the_others, chosen_is_true]):
      return chosen_action()

  return default_action, remaining_predicates, remaining_actions
def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
                                       allow_python_preds):
  """Verifies input arguments for the case function.

  A plain `dict` is turned into a list of pairs sorted by predicate name; an
  `OrderedDict` keeps its own order.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor,
      and a callable which returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for the case operation.
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If a predicate is not a bool Tensor (or, when
               `allow_python_preds` is true, a Python bool).
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.

  Returns:
    a tuple <list of scalar bool tensors, list of callables>.
  """
  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
    raise TypeError("fns must be a list, tuple, or dict")

  if isinstance(pred_fn_pairs, collections.OrderedDict):
    pred_fn_pairs = pred_fn_pairs.items()
  elif isinstance(pred_fn_pairs, dict):
    pred_fn_pairs = sorted(pred_fn_pairs.items(), key=lambda item: item[0].name)
    if not exclusive:
      logging.warn("%s: An unordered dictionary of predicate/fn pairs was "
                   "provided, but exclusive=False. The order of conditional "
                   "tests is deterministic but not guaranteed.", name)
  for pred_fn_pair in pred_fn_pairs:
    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
      raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
    pred, fn = pred_fn_pair

    if isinstance(pred, ops.Tensor):
      if pred.dtype != dtypes.bool:
        raise TypeError("pred must be Tensor of type bool: %s" % pred.name)
    elif not allow_python_preds:
      # `pred` is wrapped in a 1-tuple so that a tuple-valued `pred` is
      # formatted as a whole instead of being unpacked by `%`.
      raise TypeError("pred must be a Tensor, got: %s" % (pred,))
    elif not isinstance(pred, bool):
      raise TypeError("pred must be a Tensor or bool, got: %s" % (pred,))

    if not callable(fn):
      # A Python bool predicate has no `name`; fall back to the predicate
      # itself so the intended TypeError is raised, not an AttributeError.
      raise TypeError("fn for pred %s must be callable." %
                      (getattr(pred, "name", pred),))

  predicates, actions = zip(*pred_fn_pairs)
  return predicates, actions
def _case_helper(cond_fn, pred_fn_pairs, default,
                 exclusive, name, allow_python_preds=False, **cond_kwargs):
  """Implementation of case that allows for different cond functions.

  After validating the arguments, this builds a chain of nested `cond_fn`
  calls, one per predicate, ending in `default`. If `default` is None, the
  last predicate/action pair is converted into the default.

  Args:
    cond_fn: method that has signature and semantics of `cond` above.
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors
    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  predicates, actions = _case_verify_and_canonicalize_args(
      pred_fn_pairs, exclusive, name, allow_python_preds)
  with ops.name_scope(name, "case", [predicates]):
    if default is None:
      default, predicates, actions = _case_create_default_action(
          predicates, actions)

    # The predicates must be tested in their given order, so the chain is
    # assembled back to front:
    #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
    chain = default
    pairs = list(zip(predicates, actions))
    for pred, act in pairs[::-1]:
      chain = functools.partial(
          cond_fn, pred, true_fn=act, false_fn=chain, **cond_kwargs)

    if not exclusive:
      return chain()

    at_most_one = _assert_at_most_n_true(
        predicates, n=1, msg="Input error: exclusive=True")
    with ops.control_dependencies([at_most_one]):
      return chain()
@tf_export("case")
def case(pred_fn_pairs,
         default=None,
         exclusive=False,
         strict=False,
         name="case"):
  """Create a case operation.

  `pred_fn_pairs` is a dict or a list of N pairs. Each pair holds a boolean
  scalar tensor and a python callable that creates the tensors to return when
  that boolean evaluates to True. `default` is a callable generating a list
  of tensors. All callables in `pred_fn_pairs`, and `default` if provided,
  should return the same number and types of tensors.

  If `exclusive==True`, all predicates are evaluated, and an exception is
  thrown if more than one of the predicates evaluates to `True`.
  If `exclusive==False`, execution stops at the first predicate which
  evaluates to True, and the tensors generated by the corresponding function
  are returned immediately. If none of the predicates evaluate to True, this
  operation returns the tensors generated by `default`.

  `tf.case` supports nested structures as implemented in
  `tensorflow.python.util.nest`. All of the callables must return the same
  (possibly nested) value structure of lists, tuples, and/or named tuples.
  The one exception is singleton lists and tuples: when a callable returns
  one, it is implicitly unpacked to a single value. Pass `strict=True` to
  disable this unpacking.

  If an unordered dictionary is used for `pred_fn_pairs`, the order of the
  conditional tests is not guaranteed. However, the order is guaranteed to be
  deterministic, so that variables created in conditional branches are created
  in fixed order across runs.

  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
           default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
                   callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  # `strict` is not interpreted here; it is forwarded to `cond`.
  return _case_helper(
      cond_fn=cond,
      pred_fn_pairs=pred_fn_pairs,
      default=default,
      exclusive=exclusive,
      name=name,
      allow_python_preds=False,
      strict=strict)
class XLAControlFlowContext(ControlFlowContext):
  """Base class for XLA and TPU control flow contexts."""

  def __init__(self):
    """Initializes the base context and sets the context name."""
    super(XLAControlFlowContext, self).__init__()
    self._name = "XLAControlFlowContext"

  def IsXLAContext(self):
    """Returns True: this context is an XLA context."""
    return True

  def AddOp(self, _):
    """Does nothing with the given op."""
    return None

  def AddValue(self, x):
    """Returns `x` unchanged."""
    return x
def from_control_flow_context_def(context_def, import_scope=None):
  """Deserializes `context_def` into the appropriate ControlFlowContext.

  The `cond_ctxt` field is checked first, then `while_ctxt`; the first one
  that is set selects the context class used for deserialization.

  Args:
    context_def: ControlFlowContextDef proto
    import_scope: Optional `string`. Name scope to add.

  Returns:
    A ControlFlowContext subclass

  Raises:
    NotImplementedError: If neither `cond_ctxt` nor `while_ctxt` is set.
  """
  if context_def.HasField("cond_ctxt"):
    ctxt_proto, ctxt_class = context_def.cond_ctxt, CondContext
  elif context_def.HasField("while_ctxt"):
    ctxt_proto, ctxt_class = context_def.while_ctxt, WhileContext
  else:
    raise NotImplementedError("Unknown ControlFlowContextDef field: %s"
                              % context_def.WhichOneof("ctxt"))
  return ctxt_class.from_proto(ctxt_proto, import_scope=import_scope)
# Associate each control flow context collection key with its proto type and
# with the `to_proto` / `from_proto` converters of the matching context class.
ops.register_proto_function(
    ops.GraphKeys.COND_CONTEXT,
    proto_type=control_flow_pb2.CondContextDef,
    to_proto=CondContext.to_proto,
    from_proto=CondContext.from_proto)

# Same registration for while-loop contexts.
ops.register_proto_function(
    ops.GraphKeys.WHILE_CONTEXT,
    proto_type=control_flow_pb2.WhileContextDef,
    to_proto=WhileContext.to_proto,
    from_proto=WhileContext.from_proto)