| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """TensorFlow-related utilities.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import six |
| |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import composite_tensor |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import smart_cond as smart_module |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.keras import backend as K |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import object_identity |
| from tensorflow.python.util import tf_contextlib |
| |
| |
| def smart_cond(pred, true_fn=None, false_fn=None, name=None): |
| """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. |
| |
  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`; otherwise we use `tf.cond` to trace both branches and
  select one dynamically at runtime.
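
  Example (a minimal sketch):

  ```python
  # With a Python bool the predicate is resolved statically, so only one
  # of the two callables is ever executed.
  value = smart_cond(True,
                     true_fn=lambda: tf.constant(1),
                     false_fn=lambda: tf.constant(2))  # -> tf.constant(1)
  ```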
| |
| Arguments: |
| pred: A scalar determining whether to return the result of `true_fn` or |
| `false_fn`. |
    true_fn: The callable to be executed if `pred` is true.
    false_fn: The callable to be executed if `pred` is false.
| name: Optional name prefix when using `tf.cond`. |
| |
| Returns: |
| Tensors returned by the call to either `true_fn` or `false_fn`. |
| |
| Raises: |
| TypeError: If `true_fn` or `false_fn` is not callable. |
| """ |
| if isinstance(pred, variables.Variable): |
| return control_flow_ops.cond( |
| pred, true_fn=true_fn, false_fn=false_fn, name=name) |
| return smart_module.smart_cond( |
| pred, true_fn=true_fn, false_fn=false_fn, name=name) |
| |
| |
| def constant_value(pred): |
| """Return the bool value for `pred`, or None if `pred` had a dynamic value. |
| |
| Arguments: |
| pred: A scalar, either a Python bool or a TensorFlow boolean variable |
| or tensor, or the Python integer 1 or 0. |
| |
| Returns: |
| True or False if `pred` has a constant boolean value, None otherwise. |
| |
| Raises: |
    TypeError: If `pred` is not a Variable, Tensor, or bool, or a Python
      integer 1 or 0.
| """ |
| # Allow integer booleans. |
| if isinstance(pred, int): |
| if pred == 1: |
| pred = True |
| elif pred == 0: |
| pred = False |
| |
| if isinstance(pred, variables.Variable): |
| return None |
| return smart_module.smart_constant_value(pred) |
| |
| |
def is_tensor_or_tensor_list(v):
  """Returns `True` if `v` is a tensor or a (nested) list of tensors.

  Only the first element of the flattened structure is checked.
  """
  v = nest.flatten(v)
  return bool(v) and isinstance(v[0], ops.Tensor)
| |
| |
| def get_reachable_from_inputs(inputs, targets=None): |
| """Returns the set of tensors/ops reachable from `inputs`. |
| |
  Stops once all targets have been found (`targets` is optional).

  Only valid in symbolic (graph) mode, not eager mode.
| |
| Args: |
    inputs: List of tensors.
    targets: Optional list of tensors to search for; traversal stops early
      once all of them have been found.
| |
| Returns: |
| A set of tensors reachable from the inputs (includes the inputs themselves). |
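
  Example (a minimal sketch, assuming graph mode):

  ```python
  with tf.Graph().as_default():
    x = tf.constant(1.)
    y = x * 2.
    z = y + 1.
    reachable = get_reachable_from_inputs([x])
    # `reachable` now contains x, y, z, and the ops that produced y and z.
  ```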
| """ |
| inputs = nest.flatten(inputs, expand_composites=True) |
| reachable = object_identity.ObjectIdentitySet(inputs) |
| if targets: |
| remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets)) |
| queue = inputs[:] |
| |
| while queue: |
| x = queue.pop() |
| if isinstance(x, tuple(_user_convertible_tensor_types)): |
      # Can't find consumers of user-registered convertible types.
| continue |
| |
| if isinstance(x, ops.Operation): |
| outputs = x.outputs[:] or [] |
| outputs += x._control_outputs # pylint: disable=protected-access |
| elif isinstance(x, variables.Variable): |
| try: |
| outputs = [x.op] |
| except AttributeError: |
| # Variables can be created in an Eager context. |
| outputs = [] |
| elif tensor_util.is_tensor(x): |
| outputs = x.consumers() |
| else: |
| raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x)) |
| |
| for y in outputs: |
| if y not in reachable: |
| reachable.add(y) |
| if targets: |
| remaining_targets.discard(y) |
| queue.insert(0, y) |
| |
| if targets and not remaining_targets: |
| return reachable |
| |
| return reachable |
| |
| |
| # This function needs access to private functions of `nest`. |
| # pylint: disable=protected-access |
| def map_structure_with_atomic(is_atomic_fn, map_fn, nested): |
| """Maps the atomic elements of a nested structure. |
| |
| Arguments: |
| is_atomic_fn: A function that determines if an element of `nested` is |
| atomic. |
| map_fn: The function to apply to atomic elements of `nested`. |
| nested: A nested structure. |
| |
| Returns: |
| The nested structure, with atomic elements mapped according to `map_fn`. |
| |
| Raises: |
| ValueError: If an element that is neither atomic nor a sequence is |
| encountered. |
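
  Example (a minimal sketch):

  ```python
  # Treat ints as atomic and double them, preserving the structure.
  map_structure_with_atomic(
      lambda x: isinstance(x, int),
      lambda x: 2 * x,
      {'a': 1, 'b': [2, 3]})
  # -> {'a': 2, 'b': [4, 6]}
  ```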
| """ |
| if is_atomic_fn(nested): |
| return map_fn(nested) |
| |
| # Recursively convert. |
| if not nest.is_sequence(nested): |
| raise ValueError( |
| 'Received non-atomic and non-sequence element: {}'.format(nested)) |
| if nest._is_mapping(nested): |
| values = [nested[k] for k in nest._sorted(nested)] |
| else: |
| values = nested |
| mapped_values = [ |
| map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values |
| ] |
| return nest._sequence_like(nested, mapped_values) |
| |
| |
| # pylint: enable=protected-access |
| |
| |
| def convert_shapes(input_shape, to_tuples=True): |
| """Converts nested shape representations to desired format. |
| |
| Performs: |
| |
| TensorShapes -> tuples if `to_tuples=True`. |
| tuples of int or None -> TensorShapes if `to_tuples=False`. |
| |
| Valid objects to be converted are: |
| - TensorShapes |
| - tuples with elements of type int or None. |
| - ints |
| - None |
| |
| Arguments: |
| input_shape: A nested structure of objects to be converted to TensorShapes. |
| to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts |
| all tuples representing shapes to TensorShapes. |
| |
| Returns: |
| Nested structure of shapes in desired format. |
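
  Example (a minimal sketch of both directions):

  ```python
  convert_shapes([tf.TensorShape([None, 4]), (3,)], to_tuples=True)
  # -> [(None, 4), (3,)]
  convert_shapes([(None, 4)], to_tuples=False)
  # -> [TensorShape([None, 4])]
  ```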
| """ |
| |
| def _is_shape_component(value): |
| return value is None or isinstance(value, (int, tensor_shape.Dimension)) |
| |
| def _is_atomic_shape(input_shape): |
| # Ex: TensorShape or (None, 10, 32) or 5 or `None` |
| if _is_shape_component(input_shape): |
| return True |
| if isinstance(input_shape, tensor_shape.TensorShape): |
| return True |
| if (isinstance(input_shape, (tuple, list)) and |
| all(_is_shape_component(ele) for ele in input_shape)): |
| return True |
| return False |
| |
| def _convert_shape(input_shape): |
| input_shape = tensor_shape.TensorShape(input_shape) |
| if to_tuples: |
| input_shape = tuple(input_shape.as_list()) |
| return input_shape |
| |
| return map_structure_with_atomic(_is_atomic_shape, _convert_shape, |
| input_shape) |
| |
| |
| class ListWrapper(object): |
| """A wrapper for lists to be treated as elements for `nest`.""" |
| |
| def __init__(self, list_to_wrap): |
| self._list = list_to_wrap |
| |
| def as_list(self): |
| return self._list |
| |
| |
| def convert_inner_node_data(nested, wrap=False): |
| """Either wraps or unwraps innermost node data lists in `ListWrapper` objects. |
| |
| Arguments: |
| nested: A nested data structure. |
| wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`, |
| unwraps `ListWrapper` objects into lists. |
| |
| Returns: |
| Structure of same type as nested, with lists wrapped/unwrapped. |
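
  Example (a minimal sketch):

  ```python
  # Innermost node-data lists become `ListWrapper`s, so `nest` treats
  # them as leaves instead of flattening them.
  data = [[['in_layer', 0, 0], ['dense', 0, 0]]]
  convert_inner_node_data(data, wrap=True)
  # -> [[ListWrapper(['in_layer', 0, 0]), ListWrapper(['dense', 0, 0])]]
  ```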
| """ |
| |
| def _is_serialized_node_data(nested): |
| # Node data can be of form `[layer_name, node_id, tensor_id]` or |
| # `[layer_name, node_id, tensor_id, kwargs]`. |
| if (isinstance(nested, list) and (len(nested) in [3, 4]) and |
| isinstance(nested[0], six.string_types)): |
| return True |
| return False |
| |
| def _is_atomic_nested(nested): |
| """Returns `True` if `nested` is a list representing node data.""" |
| if isinstance(nested, ListWrapper): |
| return True |
| if _is_serialized_node_data(nested): |
| return True |
| return not nest.is_sequence(nested) |
| |
| def _convert_object_or_list(nested): |
| """Convert b/t `ListWrapper` object and list representations.""" |
| if wrap: |
| if isinstance(nested, ListWrapper): |
| return nested |
| if _is_serialized_node_data(nested): |
| return ListWrapper(nested) |
| return nested |
| else: |
| if isinstance(nested, ListWrapper): |
| return nested.as_list() |
| return nested |
| |
| return map_structure_with_atomic(_is_atomic_nested, _convert_object_or_list, |
| nested) |
| |
| |
| def shape_type_conversion(fn): |
| """Decorator that handles tuple/TensorShape conversion. |
| |
| Used in `compute_output_shape` and `build`. |
| |
| Arguments: |
| fn: function to wrap. |
| |
| Returns: |
| Wrapped function. |
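
  Example (a minimal sketch on a hypothetical custom layer):

  ```python
  class MyLayer(Layer):

    @shape_type_conversion
    def compute_output_shape(self, input_shape):
      # `input_shape` arrives here as a tuple, e.g. (None, 4).
      return input_shape[:-1] + (2 * input_shape[-1],)
  ```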
| """ |
| |
| def wrapper(instance, input_shape): |
| # Pass shapes as tuples to `fn` |
| # This preserves compatibility with external Keras. |
| if input_shape is not None: |
| input_shape = convert_shapes(input_shape, to_tuples=True) |
| output_shape = fn(instance, input_shape) |
| # Return shapes from `fn` as TensorShapes. |
| if output_shape is not None: |
| output_shape = convert_shapes(output_shape, to_tuples=False) |
| return output_shape |
| |
| return wrapper |
| |
| |
def are_all_symbolic_tensors(tensors):
  """Returns `True` if every tensor in `tensors` is symbolic."""
  return all(is_symbolic_tensor(tensor) for tensor in tensors)
| |
| |
| _user_convertible_tensor_types = set() |
| |
| |
| def is_symbolic_tensor(tensor): |
| """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor. |
| |
  A Variable can be seen as either kind: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.
| |
| Arguments: |
| tensor: A tensor instance to test. |
| |
| Returns: |
| True for symbolic tensors, False for eager tensors. |
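
  Example (a minimal sketch, assuming eager execution is the default):

  ```python
  is_symbolic_tensor(tf.constant(0.))  # False when executing eagerly.
  with tf.Graph().as_default():
    is_symbolic_tensor(tf.constant(0.))  # True for a graph tensor.
  ```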
| """ |
| if isinstance(tensor, tuple(_user_convertible_tensor_types)): |
| tensor = ops.convert_to_tensor_or_composite(tensor) |
| if isinstance(tensor, variables.Variable): |
| # Variables that are output of a Keras Layer in Functional API mode |
| # should be considered symbolic. |
| # TODO(omalleyt): We need a better way to check this in order to |
| # enable `run_eagerly=True` for Models containing Layers that |
| # return Variables as outputs. |
| return (getattr(tensor, '_keras_history', False) or |
| not context.executing_eagerly()) |
| if isinstance(tensor, composite_tensor.CompositeTensor): |
| return tensor._is_graph_tensor # pylint: disable=protected-access |
| if isinstance(tensor, ops.Tensor): |
| return hasattr(tensor, 'graph') |
| return False |
| |
| |
| def register_symbolic_tensor_type(cls): |
| """Allows users to specify types regarded as symbolic `Tensor`s. |
| |
| Used in conjunction with `tf.register_tensor_conversion_function`, calling |
| `tf.keras.utils.register_symbolic_tensor_type(cls)` allows non-`Tensor` |
| objects to be plumbed through Keras layers. |
| |
| Example: |
| |
| ```python |
| # One-time setup. |
| class Foo(object): |
| def __init__(self, input_): |
| self._input = input_ |
| def value(self): |
| return tf.constant(42.) |
| |
| tf.register_tensor_conversion_function( |
| Foo, lambda x, *args, **kwargs: x.value()) |
| |
| tf.keras.utils.register_symbolic_tensor_type(Foo) |
| |
| # User-land. |
| layer = tf.keras.layers.Lambda(lambda input_: Foo(input_)) |
| ``` |
| |
| Arguments: |
| cls: A `class` type which shall be regarded as a symbolic `Tensor`. |
| """ |
| global _user_convertible_tensor_types |
| _user_convertible_tensor_types.add(cls) |
| |
| |
def is_tensor_or_variable(x):
  """Returns `True` if `x` is a Tensor or a Variable."""
  return tensor_util.is_tensor(x) or isinstance(x, variables.Variable)
| |
| |
| def assert_no_legacy_layers(layers): |
| """Prevent tf.layers.Layers from being used with Keras. |
| |
  Certain legacy layers inherit from their Keras analogs; however, they are
  not supported with Keras and can lead to subtle, hard-to-diagnose bugs.
| |
  Args:
    layers: A list of layers to check.

  Raises:
    TypeError: If any element of `layers` is a tf.layers.Layer.
| """ |
| |
| # isinstance check for tf.layers.Layer introduces a circular dependency. |
| legacy_layers = [l for l in layers if getattr(l, '_is_legacy_layer', None)] |
| if legacy_layers: |
| layer_str = '\n'.join([' ' + str(l) for l in legacy_layers]) |
| raise TypeError( |
| 'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a ' |
| 'framework (for instance using the Network, Model, or Sequential ' |
| 'classes), please use the tf.keras.layers implementation instead. ' |
| '(Or, if writing custom layers, subclass from tf.keras.layers rather ' |
| 'than tf.layers)'.format(layer_str)) |
| |
| |
| @tf_contextlib.contextmanager |
| def maybe_init_scope(layer): |
| """Open an `init_scope` if in V2 mode and using the keras graph. |
| |
| Arguments: |
| layer: The Layer/Model that is currently active. |
| |
| Yields: |
| None |
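
  Example (a minimal sketch inside a hypothetical `Layer.build`):

  ```python
  def build(self, input_shape):
    with maybe_init_scope(self):
      # In V2 mode, variables created here are lifted out of any
      # function-building graph.
      self.kernel = self.add_weight('kernel', shape=(input_shape[-1], 4))
  ```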
| """ |
| # Don't open an init_scope in V1 mode or when using legacy tf.layers. |
| if (ops.executing_eagerly_outside_functions() and |
| getattr(layer, '_keras_style', True)): |
| with ops.init_scope(): |
| yield |
| else: |
| yield |
| |
| |
| @tf_contextlib.contextmanager |
| def graph_context_for_symbolic_tensors(*args, **kwargs): |
| """Returns graph context manager if any of the inputs is a symbolic tensor.""" |
| if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())): |
| with K.get_graph().as_default(): |
| yield |
| else: |
| yield |