| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Control flow statements: loops, conditionals, etc. |
| |
| Note: most of these operators accept pairs of get_state/set_state functions, to |
| capture mutations that the corresponding code blocks might make. These |
| mutations only need to be captured when staging the control flow, and they just |
| work when reverting to Python behavior. |
| |
| __Examples__ |
| |
| ``` |
| while cond: |
| self.x += i |
| ``` |
| |
| When the functionalized version is executed as a Python loop, it just works: |
| |
| ``` |
| def loop_body(): |
| self.x += i # works as expected for Python loops |
| ``` |
| |
| But it won't work for TF loops: |
| |
| ``` |
| def loop_body(): |
| self.x += i # self.x has the wrong value! |
| ``` |
| |
| get_state/set_state allow piping the mutations through the loop variables as |
| well, in effect changing the loop body: |
| |
| ``` |
| def loop_body(self_x): |
| self.x = self_x # self.x now has the proper value |
| self.x += i # the original block |
| self_x = self.x # write self.x back into the loop vars |
| return self_x |
| |
| self_x = tf.while_loop(...) |
| self.x = self_x # the result is now properly captured |
| ``` |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import functools |
| import traceback |
| |
| import numpy as np |
| |
| from tensorflow.python.autograph.operators import py_builtins |
| from tensorflow.python.autograph.operators import variables |
| from tensorflow.python.autograph.utils import ag_logging |
| from tensorflow.python.autograph.utils import misc |
| from tensorflow.python.autograph.utils import tensors |
| from tensorflow.python.data.experimental.ops import take_while_ops |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.data.ops import iterator_ops |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors_impl |
| from tensorflow.python.framework import func_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import control_flow_util |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import tensor_array_ops |
| from tensorflow.python.ops.ragged import ragged_tensor |
| from tensorflow.python.types import distribute |
| from tensorflow.python.util import nest |
| |
| |
# Upper bound on the iteration counter kept by _PythonLoopChecker; once the
# counter exceeds it, the checker raises ValueError('iteration limit exceeded').
PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
# Initial value of _PythonLoopChecker.check_inefficient_unroll, i.e. whether
# Python loops are inspected for the "large unrolled loop" warning.
WARN_INEFFICIENT_UNROLL = True
# The checker only starts snapshotting the graph's ops before an iteration once
# its iteration counter exceeds this value.
INEFFICIENT_UNROLL_MIN_ITERATIONS = 50000
# Minimum number of ops created during one iteration for the warning to be
# logged; with fewer new ops the check returns without warning.
INEFFICIENT_UNROLL_MIN_OPS = 1
| |
| |
| # TODO(mdan): Use the custom operator pattern instead of type dispatch. |
| # An example of this pattern is found in the implementation of distributed |
| # datasets. Before it can be used though, we need to standardize the interface. |
| |
| |
def _is_none_or_undef(value):
  """Reports whether `value` is None or an AutoGraph "undefined" marker.

  AutoGraph stands in for undefined symbols with special objects of type
  `variables.Undefined` or `variables.UndefinedReturnValue`.

  Args:
    value: the value to inspect.

  Returns:
    True if `value` is None or an instance of either marker type, False
    otherwise.
  """
  if value is None:
    return True
  undefined_types = (variables.UndefinedReturnValue, variables.Undefined)
  return isinstance(value, undefined_types)
| |
| |
def _verify_tf_condition(cond, tag):
  """Converts `cond` into a scalar boolean tensor usable by TF control flow.

  Args:
    cond: the condition value; converted with `ops.convert_to_tensor_v2`.
    tag: text naming the control flow statement, used in error messages.

  Returns:
    The condition tensor. It is reshaped to a scalar unless its rank is
    statically known to be zero, in which case it is returned as converted.

  Raises:
    ValueError: if the converted condition is not of dtype bool, or if the
      product of its statically known dimensions is greater than one.
  """
  none_hint = 'to check for None, use `is not None`'
  cond = ops.convert_to_tensor_v2(cond)

  if cond.dtype != dtypes.bool:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; to use as boolean Tensor, use `tf.cast`'
        '; {}'.format(tag, cond, none_hint))

  if cond.shape is None or cond.shape.ndims is None:
    # Unknown rank: the reshape itself acts as the check.
    # TODO(mdan): Consider a explicit size check, if not too slow.
    return array_ops.reshape(cond, ())

  if cond.shape.ndims <= 0:
    # Already a scalar.
    return cond

  # Unknown dimensions are ignored; only the known ones can prove that the
  # condition holds more than a single element.
  static_dims = [dim for dim in cond.shape.as_list() if dim is not None]
  if np.prod(static_dims) > 1:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; {}'.format(tag, cond, none_hint))
  return array_ops.reshape(cond, ())
| |
| |
def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
  """Checks that every loop variable has a usable value before a TF loop.

  A variable is rejected when its initial value is None or one of AutoGraph's
  undefined markers. The value that the variable has after one iteration,
  when supplied, is only used to pick a more helpful error.

  Args:
    init_vars: initial loop variables (as taken before entering the loop);
      these may contain placeholder values derived from first_iter_vars.
    symbol_names: corresponding names of the initial loop variables.
    first_iter_vars: optional loop variables after one iteration of the loop;
      treated as all None when omitted.

  Raises:
    ValueError: if a variable is None or undefined before the loop, or if the
      return value is undefined before the loop and truthy after one
      iteration.
    NotImplementedError: if the return value is undefined before the loop and
      falsy after one iteration.
  """
  if not symbol_names:
    return
  if first_iter_vars is None:
    first_iter_vars = (None,) * len(symbol_names)

  num_symbols = len(symbol_names)
  assert num_symbols == len(init_vars)
  assert num_symbols == len(first_iter_vars)

  for name, before, after in zip(symbol_names, init_vars, first_iter_vars):
    if isinstance(before, variables.UndefinedReturnValue):
      if not after:
        # TODO(mdan): This can be handled by removing the return value.
        raise NotImplementedError(
            'a return statement cannot be placed inside this TensorFlow loop;'
            ' this may happen if a return statement depends on a'
            ' static Python condition such as a hyperparameter')
      raise ValueError(
          'the return value from a TensorFlow loop may only be a {}; got {}'
          .format(LEGAL_LOOP_TYPES, type(after)))

    if before is None:
      error_msg = "'{}' may not be None before the loop".format(name)
    elif isinstance(before, variables.Undefined):
      error_msg = "'{}' must be defined before the loop".format(name)
    else:
      continue

    # Getting here means that no placeholder could be inferred for the
    # variable: it is of an unknown type, or it is still undefined after
    # staging one iteration.
    if after:
      error_msg += ", unless it's a {}; got {}".format(
          LEGAL_LOOP_TYPES, type(after))
    else:
      # TODO(mdan): This can be handled by removing the loop var.
      error_msg += '.'
    raise ValueError(error_msg)
| |
| |
def _is_subshape(left, right):
  """Returns True if left shape is at least as specific as right shape.

  Unlike `TensorShape.is_compatible_with`, this relation is not symmetric: an
  unknown dimension in `right` accepts anything in `left`, but an unknown
  dimension in `left` does not satisfy a known dimension in `right`.

  Args:
    left: shape object exposing `ndims` and `dims`.
    right: shape object exposing `ndims` and `dims`.

  Returns:
    True if `right` has unknown rank, or if both shapes have the same rank and
    every known dimension of `right` equals the matching dimension of `left`.
  """
  # TODO(mdan): This code should be in TensorShape.
  # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  has_conflict = any(
      rdim.value is not None and ldim.value != rdim.value
      for ldim, rdim in zip(left.dims, right.dims))
  return not has_conflict
| |
| |
| # TODO(mdan): Remove these verifications once TF ops can properly report names. |
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent.

  Args:
    name: name of the loop variable, used in error messages.
    check_shape: whether shapes should be compared in addition to dtypes.
    init: value of the variable before the loop.
    entry: value of the variable at the start of an iteration.
    exit_: value of the variable at the end of that iteration.
    shape_invariant: shape that both `init` and `exit_` must conform to, or
      None to compare the exit shape against the entry shape instead.

  Raises:
    ValueError: if `exit_` is None, or if a shape check fails.
    TypeError: if the dtype changes between entry and exit.
  """
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  # Plain Python and numpy values are converted so that they expose a dtype
  # and a shape like any other tensor.
  convertible_types = (bool, int, float, str, np.ndarray)
  if isinstance(init, convertible_types):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, convertible_types):
    entry = ops.convert_to_tensor_v2(entry)
  if isinstance(exit_, convertible_types):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if not (tensor_util.is_tf_type(entry) and tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  for required_attr in ('dtype', 'shape'):
    if not (hasattr(entry, required_attr) and hasattr(exit_, required_attr)):
      return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))

  if not check_shape:
    return

  exit_shape = exit_.shape
  if shape_invariant is None:
    # Without an invariant, the shape may only become more specific.
    entry_shape = entry.shape
    if not _is_subshape(exit_shape, entry_shape):
      raise ValueError(
          "'{}' has shape {} before the loop, but shape {} after one"
          ' iteration. Use tf.autograph.experimental.set_loop_options to set'
          ' shape invariants.'.format(name, entry_shape, exit_shape))
    return

  init_shape = init.shape
  if not _is_subshape(init_shape, shape_invariant):
    raise ValueError(
        "'{}' has shape {} before the loop, which does not conform with"
        ' the shape invariant {}.'.format(name, init_shape,
                                          shape_invariant))
  if not _is_subshape(exit_shape, shape_invariant):
    raise ValueError(
        "'{}' has shape {} after one iteration, which does not conform with"
        ' the shape invariant {}.'.format(
            name, exit_shape, shape_invariant))
| |
| |
def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency.

  Each variable must keep the same nested structure before the loop, at the
  start of an iteration and at the end of it; its leaves are then compared
  one by one with `_verify_single_loop_var`.

  Args:
    init_vars: loop variables before the loop.
    iter_entry_vars: loop variables at the start of an iteration.
    iter_exit_vars: loop variables at the end of that iteration.
    symbol_names: names of the loop variables, in matching order.
    opts: dict of loop options; its 'shape_invariants' entry, when present
      and `check_shapes` is set, holds one invariant (or None) per variable.
    check_shapes: whether to verify shapes in addition to dtypes.

  Raises:
    TypeError: if the nested structure of a variable changes, or differs
      from the structure of its shape invariant.
  """
  use_invariants = check_shapes and 'shape_invariants' in opts
  shape_invariants = (
      opts['shape_invariants'] if use_invariants
      else nest.map_structure(lambda _: None, iter_entry_vars))

  num_symbols = len(symbol_names)
  assert num_symbols == len(shape_invariants)
  assert num_symbols == len(init_vars)
  assert num_symbols == len(iter_entry_vars)
  assert num_symbols == len(iter_exit_vars)

  for position, name in enumerate(symbol_names):
    init = init_vars[position]
    entry = iter_entry_vars[position]
    exit_ = iter_exit_vars[position]
    invariant = shape_invariants[position]

    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e))

    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(name, e))

    verify_leaf = functools.partial(
        _verify_single_loop_var, name, check_shapes)
    nest.map_structure(verify_leaf, init, entry, exit_, invariant)
| |
| |
def verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent.

  Args:
    name: name of the variable, used in error messages.
    body_var: value of the variable at the end of the main branch.
    orelse_var: value of the variable at the end of the else branch.

  Raises:
    ValueError: if either value is None.
    TypeError: if both values are tensors carrying a dtype and the dtypes
      differ.
  """
  if body_var is None:
    raise ValueError("'{}' is None at the end of the main branch.".format(name))
  if orelse_var is None:
    raise ValueError(
        "'{}' is None at the end of the else branch.".format(name))

  # Plain Python and numpy values are converted so their dtypes can be
  # compared like those of any other tensor.
  convertible_types = (bool, int, float, str, np.ndarray)
  if isinstance(body_var, convertible_types):
    body_var = ops.convert_to_tensor_v2(body_var)
  if isinstance(orelse_var, convertible_types):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  both_tf_types = (
      tensor_util.is_tf_type(body_var) and tensor_util.is_tf_type(orelse_var))
  if not both_tf_types:
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if not (hasattr(body_var, 'dtype') and hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        "'{}' has dtype {} in the main branch, but dtype {} in the else"
        ' branch'.format(name, body_var.dtype.name,
                         orelse_var.dtype.name))
| |
| |
def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
  """Verifies variables output by a conditional branch for consistency.

  Args:
    vars_: values of the variables at the end of the branch.
    symbol_names: names of those variables, in matching order.
    branch_name: text naming the branch, used in error messages.

  Raises:
    ValueError: if a variable is still undefined at the end of the branch, or
      if the return value is undefined there.
  """
  for symbol, value in zip(symbol_names, vars_):
    if isinstance(value, variables.Undefined):
      message = "'{}' must also be initialized in the {} branch".format(
          symbol, branch_name)
      raise ValueError(message)
    if isinstance(value, variables.UndefinedReturnValue):
      message = 'the {} branch must also have a return statement.'.format(
          branch_name)
      raise ValueError(message)
| |
| |
def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency.

  Args:
    body_vars: values of the variables at the end of the main branch.
    orelse_vars: values of the variables at the end of the else branch.
    symbol_names: names of the variables, in matching order.

  Raises:
    TypeError: if a variable has different nested structures in the two
      branches. Errors from `verify_single_cond_var` on individual leaves
      propagate unchanged.
  """
  for name, in_body, in_orelse in zip(symbol_names, body_vars, orelse_vars):
    try:
      nest.assert_same_structure(in_body, in_orelse, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError(
          "'{}' must have the same nested structure in the main and else"
          ' branches:\n\n{}'.format(name, str(e)))

    verify_leaf = functools.partial(verify_single_cond_var, name)
    nest.map_structure(verify_leaf, in_body, in_orelse)
| |
| |
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state, made up of all the symbols that vary across
  loop iterations, except for the variables that are local to the loop.

  For example, the loop below accumulates a product and a sum of some
  numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  Its state consists of the variables geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The callables representing the loop blocks take no explicit inputs and
  return no explicit outputs; they act through nonlocal/global side effects,
  and the values flowing in and out are controlled by set_state/get_state.

  The type of `iter_` selects the implementation: TF tensors, datasets,
  iterators, ragged tensors and distributed iterators/iterables each have a
  staged overload; anything else runs as a normal Python loop.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop
      condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  if tensor_util.is_tf_type(iter_):
    if tensors.is_range_tensor(iter_):
      staged_for = _tf_range_for_stmt
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
      staged_for = _tf_ragged_for_stmt
    else:
      staged_for = _known_len_tf_for_stmt
  elif isinstance(iter_, dataset_ops.DatasetV2):
    staged_for = _tf_dataset_for_stmt
  elif isinstance(iter_, iterator_ops.OwnedIterator):
    staged_for = _tf_iterator_for_stmt
  elif isinstance(iter_, ragged_tensor.RaggedTensor):
    staged_for = _tf_ragged_for_stmt
  elif isinstance(iter_, distribute.Iterator):
    staged_for = _tf_iterator_for_stmt
  elif isinstance(iter_, distribute.Iterable):
    # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
    staged_for = _tf_distributed_iterable_for_stmt
  else:
    # Plain Python iteration needs no state plumbing.
    _py_for_stmt(iter_, extra_test, body, None, None)
    return

  staged_for(iter_, extra_test, body, get_state, set_state, symbol_names, opts)
| |
| |
def _py_for_stmt(iter_, extra_test, body, get_state, set_state):
  """Overload of for_stmt that executes a Python for loop.

  Args:
    iter_: the Python iterable to loop over.
    extra_test: optional callable returning an additional loop condition; it
      is evaluated once before the loop and again after every iteration.
    body: callable taking the current element.
    get_state: unused.
    set_state: unused.

  Raises:
    NotImplementedError: if the result of `extra_test` cannot be converted to
      a Python bool because it is a TF value that does not allow it.
  """
  del get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    notify_before = checker.before_iteration
    notify_after = checker.after_iteration
    notify_before()

    unchecked_body = body
    def checked_body(target):
      unchecked_body(target)
      notify_after()
      notify_before()
    body = checked_body

  if extra_test is None:
    for target in iter_:
      body(target)
    return

  def should_continue():
    test_result = extra_test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(test_result)
    except errors_impl.OperatorNotAllowedInGraphError:
      ag_logging.log(
          1,
          'Caught error while evaluating loop stop condition',
          exc_info=True)
      # TODO(mdan): We can pass the location of extra_test and show it here.
      raise NotImplementedError(
          'break and return statements which depend on a TF condition are not'
          ' supported in Python for loops. Did you intend to make it a TF'
          ' loop?\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.')

  if not should_continue():
    return
  for target in iter_:
    body(target)
    if not should_continue():
      break
| |
| |
def _add_max_iterations_hint(opts, n):
  """Stores `n` as opts['maximum_iterations'] when building inside XLA.

  Args:
    opts: dict of loop options; modified in place.
    n: the number of iterations to record.
  """
  # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
  graph = ops.get_default_graph()
  if not control_flow_util.GraphOrParentsInXlaContext(graph):
    return
  opts['maximum_iterations'] = n
| |
| |
def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF entities that admit a length.

  The iterated value is unstacked into a TensorArray, and the loop is staged
  with `_tf_while_stmt` using an internal index that counts up to the length
  reported by `py_builtins.len_`. The index is carried as an extra, leading
  loop variable.

  Args:
    iter_: the TF value to iterate over.
    extra_test: optional callable returning an additional loop condition, or
      None.
    body: callable taking the current element.
    get_state: callable returning the tuple of loop variable values.
    set_state: callable that writes loop variable values back.
    symbol_names: tuple of loop variable names, in get_state order.
    opts: dict of extra loop options.
  """
  n = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  elements = tensor_array_ops.TensorArray(iter_.dtype, size=n).unstack(iter_)

  position = 0

  def state_with_position():
    return (position,) + get_state()

  def restore_state_with_position(aug_loop_vars):
    nonlocal position
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    # TODO(b/171479293): Drop the lint override.
    position, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(user_vars)

  def run_body_and_advance():
    nonlocal position
    body(elements.read(position))
    position += 1

  def has_more_elements():
    in_bounds = position < n
    if extra_test is None:
      return in_bounds
    # extra_test is only consulted while there are elements left.
    return control_flow_ops.cond(in_bounds, extra_test, lambda: False)

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      has_more_elements,
      run_body_and_advance,
      state_with_position,
      restore_state_with_position,
      ('<internal iterate>',) + symbol_names,
      opts,
  )
| |
| |
def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  The loop is staged with `_tf_while_stmt`; an internal index selects
  `iter_[index]` on every iteration and is carried as an extra, leading loop
  variable.

  Args:
    iter_: the ragged tensor to iterate over.
    extra_test: optional callable returning an additional loop condition, or
      None.
    body: callable taking the current element.
    get_state: callable returning the tuple of loop variable values.
    set_state: callable that writes loop variable values back.
    symbol_names: tuple of loop variable names, in get_state order.
    opts: dict of extra loop options.
  """
  _verify_loop_init_vars(get_state(), symbol_names)

  # Prefer the statically known size of the first dimension.
  # TODO(mdan): Move this into len()? Requires eager support.
  has_static_len = iter_.shape and iter_.shape[0] is not None
  if has_static_len:
    n = iter_.shape[0]
  else:
    # NOTE(review): row_lengths()[0] looks like the length of the first row
    # rather than the number of rows (nrows()) - confirm this is intended.
    n = iter_.row_lengths()[0]

  position = 0

  def state_with_position():
    return (position,) + get_state()

  def restore_state_with_position(aug_loop_vars):
    nonlocal position
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    # TODO(b/171479293): Drop the lint override.
    position, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(user_vars)

  def run_body_and_advance():
    nonlocal position
    body(iter_[position])
    position += 1

  def has_more_elements():
    in_bounds = position < n
    if extra_test is None:
      return in_bounds
    # extra_test is only consulted while there are elements left.
    return control_flow_ops.cond(in_bounds, extra_test, lambda: False)

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      has_more_elements,
      run_body_and_advance,
      state_with_position,
      restore_state_with_position,
      ('<internal iterate>',) + symbol_names,
      opts)
| |
| |
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  Rather than materializing the range, the loop is staged with
  `_tf_while_stmt` and steps a counter from the range's start towards its
  limit, adding its delta on every iteration. The counter is carried as an
  extra, leading loop variable.

  Args:
    iter_: the range tensor; start, limit and delta are read from the inputs
      of the op that produced it.
    extra_test: optional callable returning an additional loop condition, or
      None.
    body: callable taking the current value of the counter.
    get_state: callable returning the tuple of loop variable values.
    set_state: callable that writes loop variable values back.
    symbol_names: tuple of loop variable names, in get_state order.
    opts: dict of extra loop options; opts['iterate_names'] is compared with
      the loop variable names when loop variables are read.
  """
  start, limit, delta = iter_.op.inputs

  iterate = start

  def defined_or_iterate(name, value):
    # A loop variable named like the iterate that is still undefined takes the
    # current value of the iterate.
    names_the_iterate = name == opts['iterate_names']
    if names_the_iterate and isinstance(value, variables.Undefined):
      return iterate
    return value

  def state_with_iterate():
    user_state = tuple(
        defined_or_iterate(name, value)
        for name, value in zip(symbol_names, get_state()))
    return (iterate,) + user_state

  def restore_state_with_iterate(aug_loop_vars):
    nonlocal iterate
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    # TODO(b/171479293): Drop the lint override.
    iterate, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(user_vars)

  def run_body_and_step():
    nonlocal iterate
    body(iterate)
    iterate += delta

  def is_within_range():
    # TODO(b/159713842): Remove once constant folding works.
    static_delta = tensor_util.constant_value(delta)
    if static_delta is None:
      # The direction is only known at run time; test both.
      ascending = math_ops.logical_and(delta >= 0, iterate < limit)
      descending = math_ops.logical_and(delta < 0, iterate > limit)
      in_range = math_ops.logical_or(ascending, descending)
    elif static_delta >= 0:
      in_range = iterate < limit
    else:
      in_range = iterate > limit

    if extra_test is None:
      return in_range
    # extra_test is only consulted while the counter is within the range.
    return control_flow_ops.cond(in_range, extra_test, lambda: False)

  range_len = math_ops.cast(
      misc.get_range_len(start, limit, delta), dtypes.int32)
  _add_max_iterations_hint(opts, range_len)

  _tf_while_stmt(
      is_within_range,
      run_body_and_step,
      state_with_iterate,
      restore_state_with_iterate,
      ('<internal iterate>',) + symbol_names,
      opts)
| |
| |
def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_stmt.

  The loop is staged with `_tf_while_stmt`. Each iteration fetches the next
  element with `get_next_as_optional` and only runs `body` when that optional
  has a value. A boolean, `has_next`, is carried as an extra, leading loop
  variable and is what the loop test checks.

  Args:
    iter_: the iterator; must provide `get_next_as_optional()`.
    extra_test: optional callable returning an additional loop condition, or
      None.
    body: callable taking the current element.
    get_state: callable returning the tuple of loop variable values.
    set_state: callable that writes loop variable values back.
    symbol_names: tuple of loop variable names, in get_state order.
    opts: dict of extra loop options, passed on to `_verify_tf_loop_vars` and
      `_tf_while_stmt`.
  """
  # The internal flag comes first, matching the order built by aug_get_state.
  symbol_names = ('<internal has_next>',) + symbol_names
  has_next = True

  def aug_get_state():
    # Loop state as seen by the staged loop: has_next, then the user's state.
    return (has_next,) + get_state()

  def aug_set_state(aug_loop_vars):
    # Splits the internal flag off before handing the rest to set_state.
    nonlocal has_next
    # TODO(b/171479293): Drop the lint override.
    has_next, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(loop_vars)

  # Captured before the loop; also used as the reference point for the
  # verification performed inside main_path.
  init_vars = aug_get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    nonlocal has_next
    opt_iterate = iter_.get_next_as_optional()
    has_next = opt_iterate.has_value()
    loop_vars = aug_get_state()  # updated by set_state() in _tf_while_loop.

    def main_path():
      # Taken when an element is available: run the user's body on it.
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in tf_while_stmt,
      # but needs to be done earlier to prevent the tf.cond from blowing up
      # first.
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      # Taken when the iterator is exhausted: the state passes through as is.
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that get_state() _tf_while_loop sees the conditional
    # tensors.
    aug_set_state(
        control_flow_ops.cond(has_next, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    #   prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    #   -> current_iteration_body -> set_state -> has_next
    main_test = has_next
    if extra_test is not None:
      # extra_test is only evaluated while has_next holds.
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)
| |
| |
def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation.

  Datasets are typically intended for data preprocessing, whereas in autograph
  loops they usually appear as general-purpose computations (for example, a
  custom training loop). The two use cases call for significantly different
  optimization policies, the most important being device placement. Passing
  use_default_device=False instructs the runtime to treat the computation as
  general-purpose, rather than data preprocessing.

  Args:
    ds: the dataset to scan over.
    init_state: the initial scan state.
    body: the scan function.

  Returns:
    The scan dataset built over `ds`.
  """
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  scan_dataset_cls = dataset_ops._ScanDataset  # pylint:disable=protected-access
  return scan_dataset_cls(ds, init_state, body, use_default_device=False)
| |
| |
def _tf_dataset_for_stmt(
    ds, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a dataset. See for_stmt.

  The loop is expressed as a dataset pipeline: a scan threads the loop
  variables through `body`, an optional take_while implements early stopping
  based on `extra_test`, and a final reduce extracts the loop variables
  produced for the last element.

  Args:
    ds: the dataset to iterate over.
    extra_test: optional callable returning an additional loop condition, or
      None.
    body: callable taking the current element.
    get_state: callable returning the tuple of loop variable values.
    set_state: callable that writes loop variable values back.
    symbol_names: tuple of loop variable names, in get_state order.
    opts: dict of extra loop options, passed on to `_verify_tf_loop_vars`.
  """
  # Note: This is easier to follow with the insight that the computations in
  # a dataset pipeline are transposed (aka fused).
  # For example, given a pipeline input -> scan -> take_while -> reduce,
  # and a dataset with input [1, 2, 3], the computations occur in the following
  # order:
  #  reduce(take_while(scan(1)))
  #  reduce(take_while(scan(2)))
  #  reduce(take_while(scan(3)))

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  if not init_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',)

    def dummy_set_state(unused_dummy):
      pass

    def dummy_get_state():
      return (constant_op.constant(0),)

    # From here on, the dummy pair stands in for the caller's state functions.
    get_state, set_state = dummy_get_state, dummy_set_state

  def scan_body(scan_state, scan_inputs):
    """Main body of the Dataset.scan."""
    loop_vars, iterate = scan_state, scan_inputs
    # Make the incoming loop variables visible to body and extra_test.
    set_state(loop_vars)

    def main_path():
      # Runs the user's body on the current element and collects the result.
      body(iterate)
      new_loop_vars = get_state()
      _verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts,
          check_shapes=False)
      return new_loop_vars

    if extra_test is not None:
      # The body only runs when extra_test holds; otherwise the loop variables
      # pass through unchanged.
      extra_cond = extra_test()
      new_loop_vars = control_flow_ops.cond(
          extra_cond, main_path, lambda: loop_vars)
    else:
      # TODO(mdan): the optimizer should be able to remove an invariant cond?
      extra_cond = (constant_op.constant(True),)  # dummy value, unused
      new_loop_vars = main_path()

    # The loop variables are both the next scan state and, together with the
    # stop condition, the element emitted to the rest of the pipeline.
    scan_outputs = new_loop_vars, extra_cond
    new_scan_state = new_loop_vars
    return new_scan_state, scan_outputs

  def take_while_predicate(unused_loop_vars, extra_cond):
    # Keeps elements for as long as the condition computed by scan_body holds.
    return extra_cond

  def reduce_body(unused_reduce_state, scan_outputs):
    # Discards the previous state, so that the reduction ends up holding the
    # loop variables of the last element it saw.
    output_loop_vars, unused_extra_cond = scan_outputs
    new_reduce_state = output_loop_vars
    return new_reduce_state

  ds = _general_purpose_scan(ds, init_vars, scan_body)
  if extra_test is not None:
    ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_loop_vars = ds.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)
| |
| |
def _tf_distributed_iterable_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF distributed datasets.

  The loop is expressed as a reduction over `iter_`, with the loop variables
  as the reduction state.

  Args:
    iter_: the distributed iterable; must provide `reduce`.
    extra_test: must be None; early stopping is not supported here.
    body: callable taking the current element.
    get_state: callable returning the loop variable values.
    set_state: callable that writes loop variable values back.
    symbol_names: tuple of loop variable names, in get_state order.
    opts: dict of extra loop options; its 'shape_invariants' entry, when
      present, is replaced in place by its positional-list form.

  Raises:
    NotImplementedError: if `extra_test` is not None.
  """
  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  if 'shape_invariants' in opts:
    invariants_mapping = opts['shape_invariants']
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        invariants_mapping, init_vars)

  def reduce_body(loop_vars, iterate):
    # Expose the incoming state, run the user's body, then read the state
    # back and check it against the state the loop started with.
    set_state(loop_vars)
    body(iterate)
    updated_loop_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, updated_loop_vars, symbol_names, opts)
    return updated_loop_vars

  final_loop_vars = iter_.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)
| |
| |
def while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  The callables representing the loop blocks take no explicit inputs and
  return no explicit outputs; they act through nonlocal/global side effects,
  and the values flowing in and out are controlled by set_state/get_state.

  The loop is staged as a TF loop when the first evaluation of `test` yields a
  dense tensor, and runs as a normal Python loop otherwise. Nothing is
  returned; results are visible through the state.

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Optional dict of extra loop parameters.
  """
  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated in a temporary graph to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  scratch_graph = func_graph.FuncGraph('tmp')
  with scratch_graph.as_default():
    first_test_result = test()

  if tensors.is_dense_tensor(first_test_result):
    # TensorFlow: Multiple evaluations are acceptable in this case, so we're
    # fine with the re-evaluation of `test` that `_tf_while_stmt` will make.
    _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts)
  elif first_test_result:
    # Normal Python: We already consumed one evaluation of `test`;
    # consistently, unroll one iteration before dispatching to a normal loop.
    # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
    body()
    _py_while_stmt(test, body, get_state, set_state, opts)
| |
| |
class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits.

  The checker is driven by calling `before_iteration` and `after_iteration`
  around every iteration of a Python loop. It raises once the iteration counter
  exceeds PYTHON_MAX_ITERATIONS and, while the inefficient-unroll check is
  enabled, logs a one-time warning when a single iteration adds at least
  INEFFICIENT_UNROLL_MIN_OPS ops to the default graph.
  """

  __slots__ = (
      'iterations',
      'check_inefficient_unroll',
      'check_op_count_after_iteration',
      'ops_before_iteration',
  )

  def __init__(self):
    self.iterations = 1
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

    # Slots have no default value; initialize explicitly so that the attribute
    # can always be read. Holds the ops that existed before the iteration
    # currently being measured, or None when no measurement is in progress.
    self.ops_before_iteration = None

  def _get_ops(self):
    """Returns the operations currently in the default graph."""
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    """Raises ValueError once the iteration counter exceeds the limit."""
    if self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    """Permanently disables the op count checks and drops the baseline."""
    self.check_inefficient_unroll = False
    self.check_op_count_after_iteration = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop.

    Returns:
      True if a warning was logged, False otherwise.
    """
    assert self.ops_before_iteration is not None
    known_ops = self.ops_before_iteration
    # The graph's op order is kept for the message; membership is tested
    # against a set so the scan stays linear in the number of ops.
    new_ops = tuple(op for op in self._get_ops() if op not in known_ops)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    ag_logging.warning(
        'Large unrolled loop detected. Did you mean to use a TF loop?'
        ' The following ops were created after iteration %s: %s'
        '\nSee'
        ' https://github.com/tensorflow/tensorflow/blob/master/'
        'tensorflow/python/autograph/g3doc/reference/common_errors.md'
        '#warning-large-unrolled-loop-detected'
        '\n'
        'Location:'
        '\n%s'
        '', self.iterations, new_ops, '\n'.join(traceback.format_stack()))
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      # Snapshot the existing ops as a set for fast membership tests later.
      self.ops_before_iteration = frozenset(self._get_ops())
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()
| |
| |
def _py_while_stmt(test, body, get_state, set_state, opts):
  """Overload of while_stmt that executes a Python while loop.

  Args:
    test: Callable returning the loop condition.
    body: Callable executing one iteration of the loop.
    get_state: Unused by the Python loop.
    set_state: Unused by the Python loop.
    opts: Unused by the Python loop.

  Raises:
    NotImplementedError: if converting the result of `test` to bool raises
      OperatorNotAllowedInGraphError.
  """
  # None of these are needed when running a plain Python loop.
  del opts, get_state, set_state

  run_iteration = body
  if __debug__:
    loop_checker = _PythonLoopChecker()
    loop_checker.before_iteration()

    def checked_iteration():
      # Bracket every iteration with the checker's bookkeeping.
      body()
      loop_checker.after_iteration()
      loop_checker.before_iteration()

    run_iteration = checked_iteration

  def condition_holds():
    outcome = test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(outcome)
    except errors_impl.OperatorNotAllowedInGraphError:
      ag_logging.log(
          1,
          'Caught error while evaluating while loop condition',
          exc_info=True)
      # TODO(mdan): distinguish between these two cases.
      raise NotImplementedError(
          'The condition of while loop started as non-Tensor, then changed to'
          ' Tensor. This may happen either because variables changed type, or'
          ' when a break or return statement inside the loop depends on a'
          ' Tensor condition. In both cases, changing to a TF loop should'
          ' remove the error.\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.')

  while condition_holds():
    run_iteration()
| |
| |
| def _shape_invariants_mapping_to_positional_list(mapping, keys): |
| # The keys are not expected to be hashable. |
| mapping = {id(k): (k, v) for k, v in mapping} |
| result = [] |
| for k in keys: |
| map_key, map_val = mapping.get(id(k), (None, None)) |
| result.append(map_val if map_key is k else None) |
| return tuple(result) |
| |
| |
# Textual description of what a legal TF loop variable is, intended for use in
# user-facing messages. This description summarizes the types that
# _placeholder_value below can handle: Tensors, the Python scalars int, float
# and bool, and list, tuple or dict structures of those. Keep the two together
# and in sync.
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'
| |
| |
def _placeholder_value(like, shape_invariant, original=None):
  """Constructs a (dummy) placeholder value for a loop-initialized variable.

  Args:
    like: Any object. The value created by the first iteration of the loop.
      If a Python scalar, the placeholder will be the zero value of that type.
      If a Tensor, the placeholder will be a zero tensor of matching dtype;
      statically known dimensions are kept, and dynamic dimensions start at
      the size given by the shape invariant, or zero if it gives none. If a
      list, dict or tuple, the placeholder will be an identical structure of
      placeholders.
    shape_invariant: The shape invariant specified by the user (or None, if
      nothing was specified) for the respective variable.
    original: Any object. The value of the variable prior to entering the loop.
      Typically, this is one of the special "Undefined" value, because that's
      when a placeholder is needed.

  Returns:
    A tuple (placeholder, invariant). placeholder is either a zero value of
    structure, shape and dtype matching 'like', or 'original', if no such zero
    value could be created. invariant is the shape invariant to use for the
    placeholder (with the same structure as placeholder for lists, tuples and
    dicts), or None if none is needed.
  """
  if isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
    return original, None

  if isinstance(like, (int, float, bool)):
    return type(like)(0), None

  if tensor_util.is_tf_type(like):
    like_shape = shape_invariant if shape_invariant is not None else like.shape
    if like_shape is None or like_shape.rank is None:
      return array_ops.zeros((), like.dtype), like_shape

    def static_dim(dim):
      """Returns the static size of `dim`, or None if it is dynamic."""
      if isinstance(dim, tensor_shape.Dimension):
        return dim.value
      return dim

    # If the shape contains dynamic values, set the corresponding starting
    # dimension to either zero or what the shape invariant specified.
    placeholder_shape = []
    has_dynamic_dims = False
    for actual_dim, invariant_dim in zip(like.shape, like_shape):
      actual_size = static_dim(actual_dim)
      if actual_size is None:
        invariant_size = static_dim(invariant_dim)
        placeholder_shape.append(
            0 if invariant_size is None else invariant_size)
        has_dynamic_dims = True
      else:
        placeholder_shape.append(actual_size)

    # An invariant is only needed when the placeholder's static shape may
    # differ from the shape of the values produced inside the loop.
    invariant = like_shape if has_dynamic_dims else None
    return array_ops.zeros(placeholder_shape, like.dtype), invariant

  if isinstance(like, (list, tuple, dict)):
    if shape_invariant is None:
      zipped = nest.map_structure(lambda v: _placeholder_value(v, None),
                                  nest.flatten(like))
    else:
      zipped = nest.map_structure(_placeholder_value, nest.flatten(like),
                                  nest.flatten(shape_invariant))
    if not zipped:
      # A structure without leaves: zip(*zipped) would produce nothing to
      # unpack, so build the two (empty) results directly.
      return (nest.pack_sequence_as(like, []),
              nest.pack_sequence_as(like, []))
    vals, invars = zip(*zipped)
    return (nest.pack_sequence_as(like, vals),
            nest.pack_sequence_as(like, invars))

  return original, None
| |
| |
def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
     iteration)
   2. these types could be replaced by placeholders (e.g. zero values, for
     tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop
        var is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants). success is a
    boolean flag indicating whether types could be successfully inferred (step
    1 above). new_init_vars contains the loop vars, with None or undefined
    values replaced by placeholders, where possible (step 2 above).
    extra_shape_invariants contains, for each loop var, the shape invariant
    produced together with its placeholder, or None. When success is False it
    contains only None values.
  """
  state_modified = False

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      # Another complication is that non_tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to account
      # that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = ops.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()
  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    first_iter_vars = None
  finally:
    # Undo the temporary state installed for staging, whether or not the
    # staging succeeded.
    if state_modified:
      set_state(init_vars)

  if first_iter_vars is not None:
    # Note: the actual placeholder value doesn't matter, because as the staging
    # proved, it will be replaced by an actual value before being read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True
  else:
    success = False
    # No placeholders were created, so there are no extra invariants. The
    # name must still be bound for the return statement below.
    extra_shape_invariants = (None,) * len(init_vars)

  # This check runs regardless, in case we captured non-Tensor inputs.
  _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars)

  return success, init_vars, extra_shape_invariants
| |
| |
| def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars): |
| """Creates an error message asking for the loop to iterate at least once.""" |
| var_names = [] |
| for sn, n, v in zip(symbol_names, nulls, init_vars): |
| if not n: |
| continue |
| if isinstance(v, variables.UndefinedReturnValue): |
| var_names.append('the function return value') |
| else: |
| var_names.append(sn) |
| var_names = ', '.join(var_names) |
| return 'loop must iterate at least once to initialize {}'.format(var_names) |
| |
| |
def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt.

  The loop state is read with get_state, threaded through
  control_flow_ops.while_loop as loop variables, and written back with
  set_state once the loop has been staged.

  If any initial loop variable is None or undefined, an attempt is made to
  replace it with a placeholder (see _try_handling_undefineds). When that
  succeeds, an extra boolean loop variable is prepended: it starts as False,
  the body sets it to True, and an Assert on its final value reports loops that
  did not iterate at all.

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Callable returning the tuple of loop variable values.
    set_state: Callable taking a tuple of loop variable values.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Dict of extra loop parameters. It is not modified.
  """
  # Work on a private copy: the 'shape_invariants' entry is rewritten below,
  # and the caller's dict must not change as a side effect.
  opts = dict(opts)

  init_vars = get_state()
  orig_init_vars = init_vars

  nulls = tuple(_is_none_or_undef(v) for v in init_vars)
  if any(nulls):
    # Loop vars are not expected to be hashable, so invariants are looked up by
    # the identity of the initial value.
    shape_invars_by_init_vals = {
        id(v): i for v, i in opts.get('shape_invariants', ())
    }
    shape_invariants = tuple(
        shape_invars_by_init_vals.get(id(v), None) for v in orig_init_vars)
    (require_one_iteration, init_vars,
     extra_shape_invariants) = _try_handling_undefineds(body, get_state,
                                                        set_state, init_vars,
                                                        nulls, shape_invariants,
                                                        symbol_names)
  else:
    require_one_iteration = False

  if require_one_iteration:
    merged_shape_invariants = dict(shape_invars_by_init_vals)
    # This has two roles:
    #   1. Shape invariants are remapped from the old init vars to the new ones.
    #   2. Any new shape invariants created by the init vars are kept, but only
    #      if the user didn't already specified some.
    for v, nv, ni in zip(orig_init_vars, init_vars, extra_shape_invariants):
      merged_invariant = merged_shape_invariants.get(id(v), ni)
      if merged_invariant is not None:
        merged_shape_invariants[id(nv)] = merged_invariant
    merged_shape_invariants = tuple((nv, merged_shape_invariants[id(nv)])
                                    for nv in init_vars
                                    if id(nv) in merged_shape_invariants)
    if merged_shape_invariants:
      opts['shape_invariants'] = merged_shape_invariants

  def aug_test(*loop_vars):
    if require_one_iteration:
      # Drop the "has iterated" flag, which is not part of the user state.
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    return _verify_tf_condition(test(), 'while loop')

  def aug_body(*loop_vars):
    if require_one_iteration:
      # Drop the "has iterated" flag, which is not part of the user state.
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    body()
    new_loop_vars = get_state()
    _verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)

    if require_one_iteration:
      # Record that the body ran at least once.
      new_loop_vars = (True,) + new_loop_vars

    return new_loop_vars

  if 'shape_invariants' in opts:
    # Convert the (value, invariant) pairs to one entry per loop variable.
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  while_loop_opts = dict(opts)
  while_loop_opts.pop('iterate_names', None)

  # Non-v2 while_loop unpacks the results when there is only one return value.
  # This enforces consistency across versions.
  while_loop_opts['return_same_structure'] = True

  if require_one_iteration:
    aug_init_vars = (False,) + init_vars
    if 'shape_invariants' in while_loop_opts:
      # The extra flag has no shape invariant.
      while_loop_opts['shape_invariants'] = (
          (None,) + while_loop_opts['shape_invariants'])
  else:
    aug_init_vars = init_vars

  final_loop_vars = control_flow_ops.while_loop(
      aug_test, aug_body, aug_init_vars, **while_loop_opts)

  if require_one_iteration:
    # Tensor results are re-created under the assertion's control dependency so
    # that reading them runs the check.
    with ops.control_dependencies([
        control_flow_ops.Assert(final_loop_vars[0], [
            _runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars)
        ])
    ]):
      final_loop_vars = nest.map_structure(
          lambda v: (array_ops.identity(v) if tensor_util.is_tf_type(v) else v),
          final_loop_vars[1:],
      )

  set_state(final_loop_vars)
| |
| |
def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Functional form of an if statement.

  The conditional operates on a state, which includes all symbols whose values
  are a function of the branch taken.

  For example, given the code below:

  ```
    x = 1
    if x > 0:
      x = -x
  ```

  The state is represented by the variable `x`. The `body`, `orelse` and
  `set_state` functions must bind to the original `x` symbol, using `nonlocal`.

  The inputs and outputs of the callables representing the branches are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.

  A dense tensor condition is dispatched to the staged (TF) conditional; any
  other condition is evaluated as a regular Python conditional.

  Args:
    cond: Boolean.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. This function is not
      needed and should not be called when dispatching to code matching Python's
      default semantics. This is useful for checkpointing to avoid unintended
      side-effects when staging requires evaluating all code-paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement to get_state, used to
      restore checkpointed values. The single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    symbol_names: Tuple containing the names of the state variables.
    nouts: Number of variables output by the statement. Vars which are
      not outputs will not be passed through staged control flow such as
      tf.cond. This includes variables that are defined before the conditional,
      but are not used after it.
  """
  # Note: tf.cond doesn't support SparseTensor.
  if not tensors.is_dense_tensor(cond):
    _py_if_stmt(cond, body, orelse)
    return

  _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
| |
| |
def _tf_if_stmt(
    cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Overload of if_stmt that stages a TF cond.

  Both branches are wrapped so that they start from the same initial state,
  run the original block, and return the first `nouts` state values. The
  results of control_flow_ops.cond, followed by the initial values of the
  remaining (non-output) state, are then written back with set_state.

  Args:
    cond: The condition; it is validated with _verify_tf_condition.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Callable returning the tuple of state values.
    set_state: Callable taking a tuple of state values.
    symbol_names: Tuple with the names of the state values, in the same order
      as the values returned by get_state.
    nouts: Number of leading state values that are outputs of the statement.
  """
  cond = _verify_tf_condition(cond, 'if statement')

  if not nouts:
    prev_get_state, prev_set_state = get_state, set_state
    # Control flow V1 wants at least one output.
    get_state = lambda: (0,) + prev_get_state()
    set_state = lambda v: prev_set_state(v[1:])
    # The dummy value is the first state element, so its name must come first
    # as well, to keep names and values aligned during verification.
    symbol_names = ('<unused dummy>',) + tuple(symbol_names)
    nouts = 1

  init_vars = get_state()

  # TODO(mdan): Use nonlocal once we no longer need to support py2.
  # Single-element lists let the branch functions publish their results to each
  # other for the cross-branch consistency check.
  new_body_vars_ = [None]
  new_orelse_vars_ = [None]

  def aug_body():
    set_state(init_vars)
    body()
    new_body_vars = get_state()
    new_body_vars = new_body_vars[:nouts]
    new_body_vars_[0] = new_body_vars
    _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main')
    # Compare against the other branch, if it was already traced.
    if new_orelse_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names)
    return new_body_vars

  def aug_orelse():
    set_state(init_vars)
    orelse()
    new_orelse_vars = get_state()
    new_orelse_vars = new_orelse_vars[:nouts]
    new_orelse_vars_[0] = new_orelse_vars
    _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else')
    # Compare against the other branch, if it was already traced.
    if new_body_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names)
    return new_orelse_vars

  final_cond_vars = control_flow_ops.cond(
      cond, aug_body, aug_orelse, strict=True)
  # State values that are not outputs keep their initial values.
  final_cond_vars = final_cond_vars + init_vars[nouts:]

  set_state(final_cond_vars)
| |
| |
| def _py_if_stmt(cond, body, orelse): |
| """Overload of if_stmt that executes a Python if statement.""" |
| return body() if cond else orelse() |