blob: 1924c0c1b0af876d17b70891d6428ea7bad33deb [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control flow statements: loops, conditionals, etc.
Note: most of these operators accept pairs of get_state/set_state functions, to
capture mutations that the corresponding code blocks might make. These
mutations only need to be captured when staging the control flow, and they just
work when reverting to Python behavior.
__Examples__
```
while cond:
self.x += i
```
When the functionalized version is executed as a Python loop, it just works:
```
def loop_body():
self.x += i # works as expected for Python loops
```
But it won't work for TF loops:
```
def loop_body():
self.x += i # self.x has the wrong value!
```
get_state/set_state allow piping the mutations through the loop variables as
well, in effect changing the loop body:
```
def loop_body(self_x):
self.x = self_x # self.x now has the proper value
self.x += i # the original block
self_x = self.x # write self.x back into the loop vars
return self_x
self_x = tf.while_loop(...)
self.x = self_x # the result is now properly captured
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import traceback
import numpy as np
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.types import distribute
from tensorflow.python.util import nest
# Upper bound on the iteration count of a Python (unrolled) loop; once it is
# exceeded, _PythonLoopChecker raises a ValueError.
PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
# Whether _PythonLoopChecker should look for Python loops that keep creating
# new graph ops, and warn about them.
WARN_INEFFICIENT_UNROLL = True
# Op-creation checks only start once the iteration count exceeds this value.
INEFFICIENT_UNROLL_MIN_ITERATIONS = 50000
# Minimum number of new ops created during one checked iteration that
# triggers the inefficient-unroll warning.
INEFFICIENT_UNROLL_MIN_OPS = 1
# TODO(mdan): Use the custom operator pattern instead of type dispatch.
# An example of this pattern is found in the implementation of distributed
# datasets. Before it can be used though, we need to standardize the interface.
def _is_none_or_undef(value):
  """Returns True if `value` is None or an AutoGraph undefined placeholder.

  AutoGraph stands in for symbols that have no value yet using instances of
  `variables.Undefined` or `variables.UndefinedReturnValue`.

  Args:
    value: the object to inspect.

  Returns:
    A Python bool.
  """
  if value is None:
    return True
  undefined_types = (variables.UndefinedReturnValue, variables.Undefined)
  return isinstance(value, undefined_types)
def _verify_tf_condition(cond, tag):
  """Converts `cond` into a scalar boolean tensor usable in TF control flow.

  Args:
    cond: the condition value; anything accepted by `convert_to_tensor_v2`.
    tag: short description of the statement owning the condition, used in
      error messages.

  Returns:
    `cond` as a tensor. A rank-0 tensor is returned unchanged. Anything else
    (unknown rank, or a static shape with at most one known element) is
    reshaped to a scalar.

  Raises:
    ValueError: if the dtype is not `tf.bool`, or if the known dimensions
      show that the condition holds more than one element.
  """
  none_hint = 'to check for None, use `is not None`'
  cond = ops.convert_to_tensor_v2(cond)

  if cond.dtype != dtypes.bool:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; to use as boolean Tensor, use `tf.cast`'
        '; {}'.format(tag, cond, none_hint))

  shape = cond.shape
  rank = None if shape is None else shape.ndims
  if rank == 0:
    # Already a scalar.
    return cond

  if rank is not None:
    # Only statically-known dimensions can prove the condition is too large.
    static_dims = [dim for dim in shape.as_list() if dim is not None]
    if np.prod(static_dims) > 1:
      raise ValueError(
          'condition of {} expected to be `tf.bool` scalar, got {}'
          '; {}'.format(tag, cond, none_hint))

  # TODO(mdan): Consider an explicit size check, if not too slow.
  return array_ops.reshape(cond, ())
def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
  """Ensures that all values in the state are valid to use in a TF loop.

  The init_vars may contain placeholder values derived from first_iter_vars.

  Args:
    init_vars: initial loop variables (as taken before entering the loop)
    symbol_names: corresponding names of the initial loop variables
    first_iter_vars: loop variables after one iteration of the loop, or None
      when that information is not available.

  Raises:
    ValueError: if a loop variable is None or undefined before the loop, or
      if a return value inside the loop has a first-iteration value of an
      unsupported type.
    NotImplementedError: if a return statement inside the loop has no
      first-iteration value.
  """
  if not symbol_names:
    return
  if first_iter_vars is None:
    first_iter_vars = (None,) * len(symbol_names)
  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(first_iter_vars)

  for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars):
    # Test for presence explicitly instead of relying on truthiness. bool() is
    # ambiguous for NumPy arrays and disallowed for symbolic tensors, and a
    # falsy but present value would otherwise hide the type hint.
    has_fi_val = fi_val is not None

    if isinstance(val, variables.UndefinedReturnValue):
      if has_fi_val:
        raise ValueError(
            'the return value from a TensorFlow loop may only be a {}; got {}'
            .format(LEGAL_LOOP_TYPES, type(fi_val)))
      # TODO(mdan): This can be handled by removing the return value.
      raise NotImplementedError(
          'a return statement cannot be placed inside this TensorFlow loop;'
          ' this may happen if a return statement depends on a'
          ' static Python condition such as a hyperparameter')

    error_msg = None
    if val is None:
      error_msg = "'{}' may not be None before the loop".format(name)
    elif isinstance(val, variables.Undefined):
      error_msg = "'{}' must be defined before the loop".format(name)

    # This only happens when we could not infer a placeholder for the
    # variable. The canonical case when that happens is when _placeholder_value
    # could not infer a placeholder for it. That means it's of an unknown type
    # or it's still undefined after staging one iteration.
    if error_msg is not None:
      if has_fi_val:
        error_msg += (", unless it's a {}; got {}".format(
            LEGAL_LOOP_TYPES, type(fi_val)))
      else:
        # TODO(mdan): This can be handled by removing the loop var.
        error_msg += '.'
      raise ValueError(error_msg)
def _is_subshape(left, right):
"""Returns True if left shape is at least as specific as right shape."""
# TODO(mdan): This code should be in TensorShape.
# Note: this is not the same as TensorShape.is_compatible_with, which is
# symmetric.
# This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
if right.dims is None:
return True
if left.ndims != right.ndims:
return False
for ldim, rdim in zip(left.dims, right.dims):
if rdim.value is not None and ldim.value != rdim.value:
return False
return True
# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent.

  Python bools, ints, floats, strings and NumPy arrays are first converted to
  tensors. The checks are skipped unless both `entry` and `exit_` are TF
  values exposing `dtype` and `shape`.

  Args:
    name: symbol name of the loop variable, used in error messages.
    check_shape: whether shapes should be verified as well as dtypes.
    init: value of the variable before the loop.
    entry: value of the variable at the start of an iteration.
    exit_: value of the variable at the end of that iteration.
    shape_invariant: user-specified shape invariant, or None.

  Raises:
    ValueError: if `exit_` is None, or a shape check fails.
    TypeError: if the dtype differs between `entry` and `exit_`.
  """
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  convertible_types = (bool, int, float, str, np.ndarray)

  def as_tensor_if_convertible(value):
    if isinstance(value, convertible_types):
      return ops.convert_to_tensor_v2(value)
    return value

  init = as_tensor_if_convertible(init)
  entry = as_tensor_if_convertible(entry)
  exit_ = as_tensor_if_convertible(exit_)

  if not (tensor_util.is_tf_type(entry) and tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  for required_attr in ('dtype', 'shape'):
    if not (hasattr(entry, required_attr) and hasattr(exit_, required_attr)):
      return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(name, entry.dtype.name, exit_.dtype.name))

  if not check_shape:
    return

  exit_shape = exit_.shape
  if shape_invariant is None:
    # Without an invariant, the shape may only become more specific.
    entry_shape = entry.shape
    if not _is_subshape(exit_shape, entry_shape):
      raise ValueError(
          "'{}' has shape {} before the loop, but shape {} after one"
          ' iteration. Use tf.autograph.experimental.set_loop_options to set'
          ' shape invariants.'.format(name, entry_shape, exit_shape))
    return

  # With an invariant, both the initial and the updated value must obey it.
  init_shape = init.shape
  if not _is_subshape(init_shape, shape_invariant):
    raise ValueError(
        "'{}' has shape {} before the loop, which does not conform with"
        ' the shape invariant {}.'.format(name, init_shape, shape_invariant))
  if not _is_subshape(exit_shape, shape_invariant):
    raise ValueError(
        "'{}' has shape {} after one iteration, which does not conform with"
        ' the shape invariant {}.'.format(name, exit_shape, shape_invariant))
def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency.

  For each loop variable, checks that its initial, entry and exit values share
  the same nested structure. If a shape invariant is given, the initial value
  must also match the invariant's structure. Each leaf is then checked with
  `_verify_single_loop_var`.

  Args:
    init_vars: values of the loop variables before the loop.
    iter_entry_vars: values at the start of an iteration.
    iter_exit_vars: values at the end of that iteration.
    symbol_names: names of the loop variables, used in error messages.
    opts: dict of loop options. If present, its 'shape_invariants' entry is
      indexed positionally alongside the sequences above.
    check_shapes: whether to verify shapes (and honor the shape invariants).

  Raises:
    TypeError: if nested structures or dtypes are inconsistent.
    ValueError: if a value is None after the iteration or a shape check fails.
  """
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  assert len(symbol_names) == len(shape_invariants)
  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(iter_entry_vars)
  assert len(symbol_names) == len(iter_exit_vars)

  named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
                   shape_invariants)
  for name, init, entry, exit_, invariant in named_vars:
    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e)) from e

    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(
                            name, e)) from e

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)
def verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent.

  Python bools, ints, floats, strings and NumPy arrays are converted to
  tensors first. The dtype check only applies when both values are TF values
  with a `dtype` attribute.

  Args:
    name: symbol name, used in error messages.
    body_var: value of the symbol at the end of the main branch.
    orelse_var: value of the symbol at the end of the else branch.

  Raises:
    ValueError: if either value is None.
    TypeError: if the two branches produce different dtypes.
  """
  if body_var is None:
    raise ValueError("'{}' is None at the end of the main branch.".format(name))
  if orelse_var is None:
    raise ValueError(
        "'{}' is None at the end of the else branch.".format(name))

  convertible_types = (bool, int, float, str, np.ndarray)
  if isinstance(body_var, convertible_types):
    body_var = ops.convert_to_tensor_v2(body_var)
  if isinstance(orelse_var, convertible_types):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  both_tf = (tensor_util.is_tf_type(body_var) and
             tensor_util.is_tf_type(orelse_var))
  if not both_tf:
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if not (hasattr(body_var, 'dtype') and hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype == orelse_var.dtype:
    return
  raise TypeError(
      "'{}' has dtype {} in the main branch, but dtype {} in the else"
      ' branch'.format(name, body_var.dtype.name, orelse_var.dtype.name))
def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
  """Verifies variables output by a conditional branch for consistency.

  Args:
    vars_: values of the conditional's output symbols at the end of the branch.
    symbol_names: names of those symbols, aligned with `vars_`.
    branch_name: name of the branch, used in error messages.

  Raises:
    ValueError: if the branch leaves a symbol undefined, or lacks a return
      statement that the other branch has.
  """
  for symbol, value in zip(symbol_names, vars_):
    left_undefined = isinstance(value, variables.Undefined)
    if left_undefined:
      raise ValueError(
          "'{}' must also be initialized in the {} branch".format(
              symbol, branch_name))
    missing_return = isinstance(value, variables.UndefinedReturnValue)
    if missing_return:
      raise ValueError(
          'the {} branch must also have a return statement.'.format(
              branch_name))
def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency.

  Each symbol must have the same nested structure in both branches. Its
  leaves are then checked pairwise with `verify_single_cond_var`.

  Args:
    body_vars: values of the symbols at the end of the main branch.
    orelse_vars: values of the symbols at the end of the else branch.
    symbol_names: names of the symbols, used in error messages.

  Raises:
    TypeError: if structures or dtypes differ between the branches.
    ValueError: if a value is None in either branch.
  """
  for name, body_var, orelse_var in zip(symbol_names, body_vars, orelse_vars):
    try:
      nest.assert_same_structure(body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError) as e:
      # Chain the original error so its traceback is preserved as the cause.
      raise TypeError(
          "'{}' must have the same nested structure in the main and else"
          ' branches:\n\n{}'.format(name, str(e))) from e
    nest.map_structure(
        functools.partial(verify_single_cond_var, name), body_var, orelse_var)
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state made of every symbol that varies across loop
  iterations, excluding variables that are local to the loop. For example,
  consider this loop, which computes the geometric and arithmetic means of
  some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  Here the state is `geo_mean` and `arith_mean`. The `extra_test`, `body`,
  `get_state` and `set_state` callables must bind to those original symbols
  using `nonlocal`.

  Loop state does not flow through the callables' arguments or return values.
  Instead, they act through nonlocal/global side effects, and
  get_state/set_state capture and restore that state.

  The implementation is selected from the type of `iter_`. TF range tensors,
  ragged tensors, other TF tensors, datasets, iterators and distributed
  iterables each get a staged TF loop. Anything else runs as a plain Python
  loop.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type, or None. An additional loop
      condition.
    body: Callable representing the actual loop body; it receives each
      iterate.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  if tensor_util.is_tf_type(iter_):
    if tensors.is_range_tensor(iter_):
      overload = _tf_range_for_stmt
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
      overload = _tf_ragged_for_stmt
    else:
      overload = _known_len_tf_for_stmt
  elif isinstance(iter_, dataset_ops.DatasetV2):
    overload = _tf_dataset_for_stmt
  elif isinstance(iter_, iterator_ops.OwnedIterator):
    overload = _tf_iterator_for_stmt
  elif isinstance(iter_, ragged_tensor.RaggedTensor):
    overload = _tf_ragged_for_stmt
  elif isinstance(iter_, distribute.Iterator):
    overload = _tf_iterator_for_stmt
  elif isinstance(iter_, distribute.Iterable):
    # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
    overload = _tf_distributed_iterable_for_stmt
  else:
    # Plain Python loops need no state capture.
    _py_for_stmt(iter_, extra_test, body, None, None)
    return

  overload(iter_, extra_test, body, get_state, set_state, symbol_names, opts)
def _py_for_stmt(iter_, extra_test, body, get_state, set_state):
  """Overload of for_stmt that executes a Python for loop.

  When `__debug__` is on, every iteration goes through a _PythonLoopChecker.
  If `extra_test` is given, it is evaluated before the first iteration and
  after each one, and the loop stops as soon as it is falsy.

  Raises:
    NotImplementedError: if `extra_test` returns a value that cannot be
      converted to a Python bool in graph mode.
  """
  del get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    checker.before_iteration()
    unchecked_body = body

    def checked_body(protected_iter):
      unchecked_body(protected_iter)
      checker.after_iteration()
      checker.before_iteration()

    body = checked_body

  if extra_test is None:
    for target in iter_:
      body(target)
    return

  def guarded_extra_test():
    test_value = extra_test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(test_value)
    except errors_impl.OperatorNotAllowedInGraphError:
      ag_logging.log(
          1,
          'Caught error while evaluating loop stop condition',
          exc_info=True)
      # TODO(mdan): We can pass the location of extra_test and show it here.
      raise NotImplementedError(
          'break and return statements which depend on a TF condition are not'
          ' supported in Python for loops. Did you intend to make it a TF'
          ' loop?\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.')

  if not guarded_extra_test():
    return
  for target in iter_:
    body(target)
    if not guarded_extra_test():
      break
def _add_max_iterations_hint(opts, n):
  """Records `n` as `opts['maximum_iterations']` when building in XLA.

  The hint is only set when `GraphOrParentsInXlaContext` reports that the
  default graph (or one of its parents) is in an XLA context.

  Args:
    opts: dict of loop options; updated in place.
    n: iteration count to use as the hint.
  """
  # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
  graph = ops.get_default_graph()
  if control_flow_util.GraphOrParentsInXlaContext(graph):
    opts['maximum_iterations'] = n
def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF entities that admit a length.

  The iterated tensor is unstacked into a TensorArray and read one element
  per iteration by `_tf_while_stmt`. A hidden index variable is prepended to
  the loop state.
  """
  n = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  elements = tensor_array_ops.TensorArray(iter_.dtype, size=n).unstack(iter_)

  index = 0

  def aug_get_state():
    return (index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal index
    # TODO(b/171479293): Drop the lint override.
    index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The index is internal and not written back: if the iterate is used
    # after the loop, it already appears among the regular loop vars.
    set_state(loop_vars)

  def aug_body():
    nonlocal index
    body(elements.read(index))
    index += 1

  def aug_test():
    in_bounds = index < n
    if extra_test is None:
      return in_bounds
    return control_flow_ops.cond(in_bounds, extra_test, lambda: False)

  _add_max_iterations_hint(opts, n)
  aug_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(
      aug_test, aug_body, aug_get_state, aug_set_state, aug_names, opts)
def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  Iterates over the rows (outermost dimension) of `iter_` with
  `_tf_while_stmt`. A hidden index variable is prepended to the loop state.
  """
  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    # The row count is not static, so compute it at runtime. This must be
    # nrows(): row_lengths()[0] is the length of the first row, not the number
    # of rows. int32 matches the dtype of the loop index, which starts as the
    # Python int 0 (converted by TF to an int32 tensor).
    n = iter_.nrows(out_type=dtypes.int32)

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_[iterate_index])
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)
  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The range is never materialized. Its start, limit and delta are read from
  the inputs of the op that produced `iter_`, and the iterate is advanced
  arithmetically inside `_tf_while_stmt`.
  """
  start, limit, delta = iter_.op.inputs
  iterate = start

  def aug_get_state():
    state_vars = []
    for name, var in zip(symbol_names, get_state()):
      # If the iterate symbol itself is still undefined, substitute the
      # current iterate value for it.
      undefined_iterate = (name == opts['iterate_names'] and
                           isinstance(var, variables.Undefined))
      state_vars.append(iterate if undefined_iterate else var)
    return (iterate,) + tuple(state_vars)

  def aug_set_state(aug_loop_vars):
    nonlocal iterate
    # TODO(b/171479293): Drop the lint override.
    iterate, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The internal iterate is not written back: if the iterate is used after
    # the loop, it already appears among the regular loop vars.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate
    body(iterate)
    iterate += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    static_delta = tensor_util.constant_value(delta)
    if static_delta is None:
      in_range = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, iterate < limit),
          math_ops.logical_and(delta < 0, iterate > limit))
    elif static_delta >= 0:
      in_range = iterate < limit
    else:
      in_range = iterate > limit
    if extra_test is not None:
      in_range = control_flow_ops.cond(in_range, extra_test, lambda: False)
    return in_range

  range_len = misc.get_range_len(start, limit, delta)
  _add_max_iterations_hint(opts, math_ops.cast(range_len, dtypes.int32))
  aug_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(
      aug_test, aug_body, aug_get_state, aug_set_state, aug_names, opts)
def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop.

  The loop state is extended with a hidden `has_next` flag. Each iteration
  pulls an optional element from `iter_` and runs `body` only when the
  optional holds a value. The while loop stops once `has_next` is False, or
  once `extra_test` (if given) fails.
  """
  # Prepend the hidden flag's name so names stay aligned with the extended
  # state returned by aug_get_state.
  symbol_names = ('<internal has_next>',) + symbol_names
  has_next = True

  def aug_get_state():
    return (has_next,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal has_next
    # TODO(b/171479293): Drop the lint override.
    has_next, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(loop_vars)

  init_vars = aug_get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    nonlocal has_next
    opt_iterate = iter_.get_next_as_optional()
    has_next = opt_iterate.has_value()
    # Captured after has_next was updated, so this snapshot already carries
    # the new flag value.
    loop_vars = aug_get_state()  # updated by set_state() in _tf_while_loop.

    def main_path():
      # An element is available: run the user body on it and return the
      # resulting state.
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in tf_while_stmt,
      # but needs to be done earlier to prevent the tf.cond from blowing up
      # first.
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      # No element: pass the snapshot taken above through unchanged.
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that get_state() _tf_while_loop sees the conditional
    # tensors.
    aug_set_state(
        control_flow_ops.cond(has_next, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    # prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    # -> current_iteration_body -> set_state -> has_next
    main_test = has_next
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)
def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation.

  Args:
    ds: the input dataset.
    init_state: the initial scan state.
    body: the scan function, called with (state, element).

  Returns:
    The scan dataset built on top of `ds`.
  """
  # tf.data pipelines usually serve input preprocessing, but in AutoGraph
  # loops they typically carry general-purpose computation (a custom training
  # loop, for instance). The two uses need quite different optimization
  # policies, device placement above all. use_default_device=False tells the
  # runtime to treat this scan as general-purpose computation rather than data
  # preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  scan_dataset_cls = dataset_ops._ScanDataset  # pylint:disable=protected-access
  return scan_dataset_cls(ds, init_state, body, use_default_device=False)
def _tf_dataset_for_stmt(
    ds, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop becomes a dataset pipeline with three stages:
  - a general-purpose scan threads the loop state through the elements;
  - a take_while is applied only when `extra_test` is given;
  - a reduce keeps the last emitted state, which is written back with
    `set_state`.
  """
  # Note: This is easier to follow with the insight that the computations in
  # a dataset pipeline are transposed (aka fused).
  # For example, given a pipeline input -> scan -> take_while -> reduce,
  # and a dataset with input [1, 2, 3], the computations occur in the following
  # order:
  # reduce(take_while(scan(1)))
  # reduce(take_while(scan(2)))
  # reduce(take_while(scan(3)))
  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  if not init_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',)

    def dummy_set_state(unused_dummy):
      pass

    def dummy_get_state():
      return (constant_op.constant(0),)

    # scan_body below is defined after this rebinding, so it uses these
    # dummies.
    get_state, set_state = dummy_get_state, dummy_set_state

  def scan_body(scan_state, scan_inputs):
    """Main body of the Dataset.scan."""
    loop_vars, iterate = scan_state, scan_inputs
    set_state(loop_vars)

    def main_path():
      body(iterate)
      new_loop_vars = get_state()
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts,
          check_shapes=False)
      return new_loop_vars

    if extra_test is not None:
      # extra_test is evaluated on the incoming state, before body runs. When
      # it fails, the state is passed through unchanged.
      extra_cond = extra_test()
      new_loop_vars = control_flow_ops.cond(
          extra_cond, main_path, lambda: loop_vars)
    else:
      # TODO(mdan): the optimizer should be able to remove an invariant cond?
      extra_cond = (constant_op.constant(True),)  # dummy value, unused
      new_loop_vars = main_path()

    # Each element emits (state, condition): take_while reads the condition,
    # reduce reads the state.
    scan_outputs = new_loop_vars, extra_cond
    new_scan_state = new_loop_vars
    return new_scan_state, scan_outputs

  def take_while_predicate(unused_loop_vars, extra_cond):
    return extra_cond

  def reduce_body(unused_reduce_state, scan_outputs):
    # Keeps only the most recent loop state.
    output_loop_vars, unused_extra_cond = scan_outputs
    new_reduce_state = output_loop_vars
    return new_reduce_state

  ds = _general_purpose_scan(ds, init_vars, scan_body)
  if extra_test is not None:
    ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_loop_vars = ds.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)
def _tf_distributed_iterable_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF distributed datasets.

  The loop is expressed as `iter_.reduce`, threading the loop state through
  the reduction. The final state is written back with `set_state`.

  Raises:
    NotImplementedError: if `extra_test` is set, i.e. the loop contains a
      break or return statement.
  """
  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # Shape invariants are given as (variable, invariant) pairs; turn them into
  # a tuple aligned with init_vars. Note: this updates `opts` in place.
  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  def reduce_body(loop_vars, iterate):
    set_state(loop_vars)
    body(iterate)
    updated_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, updated_vars, symbol_names, opts)
    return updated_vars

  final_vars = iter_.reduce(init_vars, reduce_body)
  set_state(final_vars)
def while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a while statement.

  The loop operates on a state: the symbols that vary across loop iterations,
  viewed either as a tuple of values or as a list of arguments of the
  corresponding types.

  Loop state does not flow through the callables' arguments or return values.
  Instead, they act through nonlocal/global side effects, and
  get_state/set_state capture and restore that state.

  Dispatch: `test` is evaluated once in a scratch graph. If the result is a
  dense tensor, the loop is staged with `_tf_while_stmt`. Otherwise it runs
  as a Python loop.

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Optional dict of extra loop parameters.

  Returns:
    None. State changes are communicated through the callables above.
  """
  # The dispatch evaluation runs in a throwaway graph to limit unwanted side
  # effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    init_test = test()

  if tensors.is_dense_tensor(init_test):
    # TensorFlow: evaluating `test` again inside _tf_while_stmt is fine.
    _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts)
    return

  # Normal Python: one evaluation of `test` was already spent, so run one
  # iteration by hand before handing off to the regular Python loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if init_test:
    body()
    _py_while_stmt(test, body, get_state, set_state, opts)
class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits.

  Counts iterations and raises once PYTHON_MAX_ITERATIONS is exceeded. After
  INEFFICIENT_UNROLL_MIN_ITERATIONS iterations, it also checks, for a few
  iterations only, whether each iteration creates new graph ops. It logs a
  warning at most once if that happens.
  """

  __slots__ = (
      'iterations',
      'check_inefficient_unroll',
      'check_op_count_after_iteration',
      'ops_before_iteration',
  )

  def __init__(self):
    self.iterations = 1
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL
    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False
    # Op snapshot taken by before_iteration; None while no check is pending.
    # Initialized here so the slot is never read before it is assigned.
    self.ops_before_iteration = None

  def _get_ops(self):
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    if self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.check_op_count_after_iteration = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop.

    Returns:
      True if a warning was logged, False otherwise.
    """
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    # Match ops by identity through a set, instead of scanning the previous
    # op list for every op, which was quadratic in the graph size. The ids
    # remain unique because ops_before_iteration keeps those ops alive.
    known_op_ids = {id(op) for op in self.ops_before_iteration}
    new_ops = tuple(
        op for op in ops_after_iteration if id(op) not in known_op_ids)
    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    ag_logging.warning(
        'Large unrolled loop detected. Did you mean to use a TF loop?'
        ' The following ops were created after iteration %s: %s'
        '\nSee'
        ' https://github.com/tensorflow/tensorflow/blob/master/'
        'tensorflow/python/autograph/g3doc/reference/common_errors.md'
        '#warning-large-unrolled-loop-detected'
        '\n'
        'Location:'
        '\n%s'
        '', self.iterations, new_ops, '\n'.join(traceback.format_stack()))
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()
def _py_while_stmt(test, body, get_state, set_state, opts):
  """Overload of while_stmt that executes a Python while loop.

  When `__debug__` is on, every iteration goes through a _PythonLoopChecker.

  Raises:
    NotImplementedError: if `test` returns a value that cannot be converted to
      a Python bool in graph mode.
  """
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    checker.before_iteration()
    unchecked_body = body

    def checked_body():
      unchecked_body()
      checker.after_iteration()
      checker.before_iteration()

    body = checked_body

  def guarded_test():
    condition = test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(condition)
    except errors_impl.OperatorNotAllowedInGraphError:
      ag_logging.log(
          1,
          'Caught error while evaluating while loop condition',
          exc_info=True)
      # TODO(mdan): distinguish between these two cases.
      raise NotImplementedError(
          'The condition of while loop started as non-Tensor, then changed to'
          ' Tensor. This may happen either because variables changed type, or'
          ' when a break or return statement inside the loop depends on a'
          ' Tensor condition. In both cases, changing to a TF loop should'
          ' remove the error.\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.')

  while True:
    if not guarded_test():
      break
    body()
def _shape_invariants_mapping_to_positional_list(mapping, keys):
# The keys are not expected to be hashable.
mapping = {id(k): (k, v) for k, v in mapping}
result = []
for k in keys:
map_key, map_val = mapping.get(id(k), (None, None))
result.append(map_val if map_key is k else None)
return tuple(result)
# Textual description of what a legal TF loop variable is. It summarizes the
# types that _placeholder_value below can turn into placeholders: Tensors,
# Python int/float/bool scalars, and list/tuple/dict structures of those.
# Keep the two together and in sync.
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'
def _placeholder_value(like, shape_invariant, original=None):
  """Constructs a (dummy) placeholder value for a loop-initialized variable.

  Args:
    like: Any object. The value created by the first iteration of the loop.
      If a Python scalar, the placeholder will be the zero value of that type.
      If a Tensor, the placeholder will be a zero tensor of matching shape and
      dtype. If a list, dict or tuple, the placeholder will be an identical
      structure of placeholders (an empty structure yields an empty one).
    shape_invariant: The shape invariant specified by the user (or None, if
      nothing was specified) for the respective variable.
    original: Any object. The value of the variable prior to entering the loop.
      Typically, this is one of the special "Undefined" values, because that's
      when a placeholder is needed.

  Returns:
    A pair (placeholder, invariant).
    placeholder is either a zero value of structure, shape and dtype matching
    'like', or 'original' if no such zero value could be created.
    invariant is the shape invariant the loop should use for the placeholder:
      - for tensors, the shape in effect (the user's invariant, else like's
        static shape) when that shape has unknown rank or like's static shape
        has unknown dimensions, and None otherwise;
      - for lists, tuples and dicts, a structure matching 'like';
      - None in all remaining cases.
  """
  if isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
    return original, None

  elif isinstance(like, (int, float, bool)):
    return type(like)(0), None

  elif tensor_util.is_tf_type(like):
    like_shape = shape_invariant if shape_invariant is not None else like.shape
    if like_shape is None or like_shape.rank is None:
      return array_ops.zeros((), like.dtype), like_shape

    # If the shape contains dynamic values, set the corresponding starting
    # dimension to either zero or what the shape invariant specified.
    placeholder_shape = []
    has_dynamic_dims = False
    for s, i in zip(like.shape, like_shape):
      if i is None:
        like_dim = 0
      elif isinstance(i, tensor_shape.Dimension):
        if i.value is None:
          like_dim = 0
        else:
          like_dim = i.value
      else:
        like_dim = i

      if s is None:
        placeholder_shape.append(like_dim)
        has_dynamic_dims = True
      elif isinstance(s, tensor_shape.Dimension):
        if s.value is None:
          placeholder_shape.append(like_dim)
          has_dynamic_dims = True
        else:
          placeholder_shape.append(s.value)
      else:
        placeholder_shape.append(s)

    if has_dynamic_dims:
      invariant = like_shape
    else:
      invariant = None

    return array_ops.zeros(placeholder_shape, like.dtype), invariant

  elif isinstance(like, (list, tuple, dict)):
    if shape_invariant is None:
      zipped = nest.map_structure(lambda v: _placeholder_value(v, None),
                                  nest.flatten(like))
    else:
      zipped = nest.map_structure(_placeholder_value, nest.flatten(like),
                                  nest.flatten(shape_invariant))

    if zipped:
      vals, invars = zip(*zipped)
    else:
      # An empty structure (e.g. `[]` or `{}`) has no leaves, and zip(*[])
      # produces nothing to unpack; pack empty sequences instead.
      vals, invars = (), ()

    return (nest.pack_sequence_as(like, vals),
            nest.pack_sequence_as(like, invars))

  return original, None
def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
     iteration)
   2. these types could be replaced by placeholders (e.g. zero values, for
     tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop
      var is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants).
    success is a boolean flag indicating whether types could be successfully
    inferred (step 1 above).
    new_init_vars contains the loop vars, with None or undefined values
    replaced by placeholders, where possible (step 2 above).
    extra_shape_invariants holds, per loop var, the shape invariant implied
    by its placeholder (or None). When success is False, it is a tuple of
    None values.
  """
  state_modified = False

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.

      # Another complication is that non_tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to account
      # that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = ops.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)

      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()

  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    first_iter_vars = None

  finally:
    # Undo the temporary-graph state so no staged values leak out.
    if state_modified:
      set_state(init_vars)

  if first_iter_vars is not None:
    # Note: the actual placeholder value doesn't matter, because as the staging
    # proved, it will be replaced by an actual value before being read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True
  else:
    # No types were inferred. Still bind the invariants so the return below
    # never hits an unbound local if the verification does not raise.
    extra_shape_invariants = (None,) * len(init_vars)
    success = False

  # This check runs regardless, in case we captured non-Tensor inputs.
  _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars)

  return success, init_vars, extra_shape_invariants
def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars):
"""Creates an error message asking for the loop to iterate at least once."""
var_names = []
for sn, n, v in zip(symbol_names, nulls, init_vars):
if not n:
continue
if isinstance(v, variables.UndefinedReturnValue):
var_names.append('the function return value')
else:
var_names.append(sn)
var_names = ', '.join(var_names)
return 'loop must iterate at least once to initialize {}'.format(var_names)
def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt.

  Args:
    test: callable returning the loop condition; called after set_state has
      installed the current loop variables.
    body: callable that runs one loop iteration through side effects.
    get_state: returns the tuple of current loop variable values.
    set_state: installs loop variable values from a tuple.
    symbol_names: tuple of loop variable names, used for verification and
      error messages.
    opts: dict of loop options. 'shape_invariants', if present, is a sequence
      of (init_value, invariant) pairs. 'iterate_names' is dropped. The
      remaining entries are forwarded to control_flow_ops.while_loop. The
      caller's dict is not modified.
  """
  # Work on a private copy: 'shape_invariants' is rewritten below (and read
  # back by aug_body), and the caller's dict must not change as a side effect.
  # Previously only one code path copied it.
  opts = dict(opts)

  init_vars = get_state()
  orig_init_vars = init_vars

  nulls = tuple(_is_none_or_undef(v) for v in init_vars)
  if any(nulls):
    shape_invars_by_init_vals = {
        id(v): i for v, i in opts.get('shape_invariants', ())
    }
    shape_invariants = tuple(
        shape_invars_by_init_vals.get(id(v), None) for v in orig_init_vars)
    (require_one_iteration, init_vars,
     extra_shape_invariants) = _try_handling_undefineds(body, get_state,
                                                        set_state, init_vars,
                                                        nulls, shape_invariants,
                                                        symbol_names)
  else:
    require_one_iteration = False

  if require_one_iteration:
    merged_shape_invariants = dict(shape_invars_by_init_vals)
    # This has two roles:
    #  1. Shape invariants are remapped from the old init vars to the new ones.
    #  2. Any new shape invariants created by the init vars are kept, but only
    #     if the user didn't already specify some.
    for v, nv, ni in zip(orig_init_vars, init_vars, extra_shape_invariants):
      merged_invariant = merged_shape_invariants.get(id(v), ni)
      if merged_invariant is not None:
        merged_shape_invariants[id(nv)] = merged_invariant
    merged_shape_invariants = tuple((nv, merged_shape_invariants[id(nv)])
                                    for nv in init_vars
                                    if id(nv) in merged_shape_invariants)
    if merged_shape_invariants:
      opts['shape_invariants'] = merged_shape_invariants

  def aug_test(*loop_vars):
    # With require_one_iteration, loop_vars[0] is the "iterated once" flag.
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    return _verify_tf_condition(test(), 'while loop')

  def aug_body(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    body()
    new_loop_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)

    if require_one_iteration:
      # Record that the body ran at least once.
      new_loop_vars = (True,) + new_loop_vars

    return new_loop_vars

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  while_loop_opts = dict(opts)
  while_loop_opts.pop('iterate_names', None)

  # Non-v2 while_loop unpacks the results when there is only one return value.
  # This enforces consistency across versions.
  while_loop_opts['return_same_structure'] = True

  if require_one_iteration:
    aug_init_vars = (False,) + init_vars
    if 'shape_invariants' in while_loop_opts:
      while_loop_opts['shape_invariants'] = (
          (None,) + while_loop_opts['shape_invariants'])
  else:
    aug_init_vars = init_vars

  final_loop_vars = control_flow_ops.while_loop(
      aug_test, aug_body, aug_init_vars, **while_loop_opts)

  if require_one_iteration:
    # Fail at runtime if the loop never ran, since the placeholders would
    # otherwise leak out as the final values.
    with ops.control_dependencies([
        control_flow_ops.Assert(final_loop_vars[0], [
            _runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars)
        ])
    ]):
      final_loop_vars = nest.map_structure(
          lambda v: (array_ops.identity(v) if tensor_util.is_tf_type(v) else v),
          final_loop_vars[1:],
      )

  set_state(final_loop_vars)
def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Functional form of an if statement.

  The conditional operates on a state: every symbol whose value depends on
  which branch is taken.

  For example, in the code below, which negates `x` when it is positive:

  ```
    x = 1
    if x > 0:
      x = -x
  ```

  the state consists of the variable `x`. The `body`, `orelse` and
  `set_state` functions must bind to the original `x` symbol, using
  `nonlocal`.

  The callables representing the conditional's blocks take no explicit inputs
  and return no explicit outputs. Instead, they act through nonlocal/global
  side effects, and the inputs and outputs are controlled by the
  get_state/set_state functions.

  Args:
    cond: The condition. Dense tensors are staged through a TF cond; any other
      value is evaluated as a Python truth value.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This gives access to
      state that branches may mutate through side effects. It is not needed,
      and should not be called, when dispatching to code matching Python's
      default semantics. It is useful for checkpointing, to avoid unintended
      side effects when staging requires evaluating all code paths.
    set_state: Function that sets the values of all composite symbols modified
      within the conditional. It is the complement of get_state, used to
      restore checkpointed values. Its single argument is a tuple with one
      value for each composite symbol that may be modified in a branch of the
      conditional, usually the result of a call to get_state.
    symbol_names: Tuple containing the names of the state symbols.
    nouts: Number of variables output by the statement. Vars which are not
      outputs will not be passed through staged control flow such as
      tf.cond. This includes variables that are defined before the
      conditional but are not used after it.
  """
  # tf.cond doesn't support SparseTensor, so only dense tensor conditions are
  # staged; everything else runs as a plain Python if.
  if not tensors.is_dense_tensor(cond):
    _py_if_stmt(cond, body, orelse)
    return
  _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
def _tf_if_stmt(
    cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Overload of if_stmt that stages a TF cond.

  Each branch is traced starting from the same initial state. The first
  `nouts` state values it produces become the cond outputs. The remaining
  state values are passed through unchanged. Whichever branch is traced
  second has its outputs checked against the first branch's.
  """
  cond = _verify_tf_condition(cond, 'if statement')

  if not nouts:
    # Control flow V1 wants at least one output, so thread a dummy constant
    # through the state.
    inner_get_state, inner_set_state = get_state, set_state

    def get_state():
      return (0,) + inner_get_state()

    def set_state(v):
      return inner_set_state(v[1:])

    symbol_names += ('<unused dummy>',)
    nouts = 1

  init_vars = get_state()

  # Outputs recorded per branch as it is traced. A dict is used instead of
  # `nonlocal` to stay compatible with Python 2.
  traced = {'main': None, 'else': None}

  def trace_branch(branch_fn, branch_name):
    set_state(init_vars)
    branch_fn()
    outputs = get_state()[:nouts]
    traced[branch_name] = outputs
    _verify_tf_cond_branch_vars(outputs, symbol_names, branch_name)
    return outputs

  def aug_body():
    body_outputs = trace_branch(body, 'main')
    if traced['else'] is not None:
      _verify_tf_cond_vars(body_outputs, traced['else'], symbol_names)
    return body_outputs

  def aug_orelse():
    orelse_outputs = trace_branch(orelse, 'else')
    if traced['main'] is not None:
      _verify_tf_cond_vars(traced['main'], orelse_outputs, symbol_names)
    return orelse_outputs

  cond_outputs = control_flow_ops.cond(cond, aug_body, aug_orelse, strict=True)
  set_state(cond_outputs + init_vars[nouts:])
def _py_if_stmt(cond, body, orelse):
"""Overload of if_stmt that executes a Python if statement."""
return body() if cond else orelse()