| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=protected-access |
| """Contains the `Node` class.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import copy |
| import json |
| import numpy as np |
| |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.keras import backend |
| from tensorflow.python.keras.engine import base_layer_utils |
| from tensorflow.python.keras.utils import tf_utils |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import serialization |
| |
| _CONSTANT_VALUE = '_CONSTANT_VALUE' |
| |
| |
| class Node(object): |
| """A `Node` describes the connectivity between two layers. |
| |
| Each time a layer is connected to some new input, |
| a node is added to `layer._inbound_nodes`. |
| Each time the output of a layer is used by another layer, |
| a node is added to `layer._outbound_nodes`. |
| |
| Arguments: |
| layer: The Layer for the Layer.__call__ this node represents. |
| call_args: The positional arguments the Layer was called with. |
| call_kwargs: The keyword arguments the Layer was called with. |
| outputs: The outputs of the Layer.__call__ |
| """ |
| |
| def __init__(self, |
| layer, |
| call_args=None, |
| call_kwargs=None, |
| outputs=None): |
| call_args = [] if call_args is None else call_args |
| call_kwargs = {} if call_kwargs is None else call_kwargs |
| outputs = [] if outputs is None else outputs |
| |
| self.layer = layer |
| self.is_input = not call_args and not call_kwargs |
| |
| # These arguments are user-provided. Copy the structures here so that |
| # future user modifications do not affect the node's metadata. |
| # We copy using map_structure rather than python's shallow or deep copy, |
| # because the args can be data structures (so shallow copy is |
| # insufficient), but individual values might not support copy.copy |
| # or be too expensive to deep copy. |
| call_args = nest.map_structure(lambda t: t, call_args) |
| call_kwargs = nest.map_structure(lambda t: t, call_kwargs) |
| self.outputs = nest.map_structure(lambda t: t, outputs) |
| self.call_args = call_args |
| self.call_kwargs = call_kwargs |
| |
| # Cached for performance. |
| self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs)) |
| # Used to avoid expensive `nest` operations in the most common case. |
| self._single_positional_tensor_passed = (not self.call_kwargs and len( |
| self.call_args) == 1 and tensor_util.is_tensor(self.call_args[0])) |
| |
| # Create TensorFlowOpLayers if needed. |
| for obj in self._flat_arguments: |
| if (isinstance(obj, ops.Tensor) and |
| base_layer_utils.needs_keras_history(obj, ignore_call_context=True)): |
| base_layer_utils.create_keras_history(obj) |
| |
| self._keras_inputs = [] |
| self._keras_inputs_ids_and_indices = [] |
| for i, ele in enumerate(self._flat_arguments): |
| if is_keras_tensor(ele): |
| self._keras_inputs.append(ele) |
| kt_id = str(id(ele)) |
| kt_index = i |
| self._keras_inputs_ids_and_indices.append((kt_id, kt_index)) |
| |
| # Wire up Node to Layers. |
| self.layer._inbound_nodes.append(self) |
| for kt in self.keras_inputs: |
| inbound_layer = kt._keras_history.layer |
| if inbound_layer is not None: # `None` for `Input` tensors. |
| inbound_layer._outbound_nodes.append(self) |
| |
| # Set metadata on outputs. |
| node_index = len(self.layer._inbound_nodes) - 1 |
| for i, tensor in enumerate(nest.flatten(outputs)): |
| tensor._keras_history = KerasHistory( |
| layer=layer, node_index=node_index, tensor_index=i) |
| |
| # Cached for performance. |
| self.flat_input_ids = [str(id(t)) for t in self._keras_inputs] |
| self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)] |
| |
| @property |
| def keras_inputs(self): |
| """Tensors input to this node that can be traced back to a `keras.Input`.""" |
| return self._keras_inputs |
| |
| @property |
| def parent_nodes(self): |
| """Returns all the `Node`s whose output this node immediately depends on.""" |
| node_deps = [] |
| for kt in self.keras_inputs: |
| layer = kt._keras_history.layer |
| node_index = kt._keras_history.node_index |
| if layer is not None: # `None` for `Input` tensors. |
| node_deps.append(layer._inbound_nodes[node_index]) |
| return node_deps |
| |
| def iterate_inbound(self): |
| """Yields tuples representing the data inbound from other nodes. |
| |
| Yields: |
| tuples like: (inbound_layer, node_index, tensor_index, tensor). |
| """ |
| for kt in self.keras_inputs: |
| keras_history = kt._keras_history |
| layer = keras_history.layer |
| node_index = keras_history.node_index |
| tensor_index = keras_history.tensor_index |
| yield layer, node_index, tensor_index, kt |
| |
| def map_arguments(self, tensor_dict): |
| """Maps Keras Tensors to computed Tensors using `tensor_dict`.""" |
| if self._single_positional_tensor_passed: |
| # Performance optimization for most common case. |
| kt_id, _ = self._keras_inputs_ids_and_indices[0] |
| return (tensor_dict[kt_id].pop(),), {} |
| else: |
| flat_arguments = copy.copy(self._flat_arguments) |
| for kt_id, kt_index in self._keras_inputs_ids_and_indices: |
| flat_arguments[kt_index] = tensor_dict[kt_id].pop() |
| |
| args, kwargs = nest.pack_sequence_as((self.call_args, self.call_kwargs), |
| flat_arguments) |
| return args, kwargs |
| |
| def serialize(self, make_node_key, node_conversion_map): |
| """Serializes `Node` for Functional API's `get_config`.""" |
| # Serialization still special-cases first argument. |
| args, kwargs = self.call_args, self.call_kwargs |
| inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs) |
| |
| # Treat everything other than first argument as a kwarg. |
| arguments = dict(zip(self.layer._call_fn_args[1:], args)) |
| arguments.update(kwargs) |
| kwargs = arguments |
| |
| kwargs = nest.map_structure(_serialize_keras_tensor, kwargs) |
| try: |
| json.dumps(kwargs, default=serialization.get_json_type) |
| except TypeError: |
| kwarg_types = nest.map_structure(type, kwargs) |
| logging.warning('Layer ' + self.layer.name + |
| ' was passed non-JSON-serializable arguments. ' + |
| 'Arguments had types: ' + |
| str(kwarg_types) + '. They will not be included ' |
| 'in the serialized model (and thus will be missing ' |
| 'at deserialization time).') |
| kwargs = {} |
| |
| # `kwargs` is added to each Tensor in the first arg. This should be |
| # changed in a future version of the serialization format. |
| def serialize_first_arg_tensor(t): |
| if is_keras_tensor(t): |
| kh = t._keras_history |
| node_index = kh.node_index |
| node_key = make_node_key(kh.layer.name, node_index) |
| new_node_index = node_conversion_map.get(node_key, 0) |
| data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs] |
| else: |
| # If an element in the first call argument did not originate as a |
| # keras tensor and is a constant value, we save it using the format |
| # ['_CONSTANT_VALUE', -1, serializaed_tensor_or_python_constant] |
| # (potentially including serialized kwargs in an optional 4th argument |
| data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs] |
| return tf_utils.ListWrapper(data) |
| |
| data = nest.map_structure(serialize_first_arg_tensor, inputs) |
| if not nest.is_sequence(data): |
| data = [data] |
| data = tf_utils.convert_inner_node_data(data) |
| return data |
| |
| ############################################################# |
| # Properties for Backwards compatibility. |
| # These only check the first input argument |
| # As nodes are internal, they may be removed in the future. |
| ############################################################# |
| |
| @property |
| def input_tensors(self): |
| if self.is_input: |
| return [self.outputs] # Used in `Layer.input`. |
| return self.call_args[0] |
| |
| @property |
| def output_tensors(self): |
| if self.is_input: |
| return [self.outputs] # Used in `Layer.input`. |
| return self.outputs |
| |
| @property |
| def input_shapes(self): |
| input_shapes = nest.map_structure(backend.int_shape, self.input_tensors) |
| if len(input_shapes) == 1 and not self.is_input: |
| return input_shapes[0] |
| return input_shapes |
| |
| @property |
| def output_shapes(self): |
| return nest.map_structure(backend.int_shape, self.output_tensors) |
| |
| @property |
| def outbound_layer(self): |
| return self.layer |
| |
| @property |
| def inbound_layers(self): |
| if self.is_input: |
| return [] |
| inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, |
| self.call_args[0]) |
| return inbound_layers |
| |
| |
| class KerasHistory( |
| collections.namedtuple('KerasHistory', |
| ['layer', 'node_index', 'tensor_index'])): |
| """Tracks the Layer call that created a Tensor, for Keras Graph Networks. |
| |
| During construction of Keras Graph Networks, this metadata is added to |
| each Tensor produced as the output of a Layer, starting with an |
| `InputLayer`. This allows Keras to track how each Tensor was produced, and |
| this information is later retraced by the `keras.engine.Network` class to |
| reconstruct the Keras Graph Network. |
| |
| Attributes: |
| layer: The Layer that produced the Tensor. |
| node_index: The specific call to the Layer that produced this Tensor. Layers |
| can be called multiple times in order to share weights. A new node is |
| created every time a Layer is called. |
| tensor_index: The output index for this Tensor. Always zero if the Layer |
| that produced this Tensor only has one output. Nested structures of |
| Tensors are deterministically assigned an index via `nest.flatten`. |
| """ |
| # Added to maintain memory and performance characteristics of `namedtuple` |
| # while subclassing. |
| __slots__ = () |
| |
| |
| def is_keras_tensor(obj): |
| return hasattr(obj, '_keras_history') |
| |
| |
| def _serialize_keras_tensor(t): |
| """Serializes a single Tensor passed to `call`.""" |
| if hasattr(t, '_keras_history'): |
| kh = t._keras_history |
| return [kh.layer.name, kh.node_index, kh.tensor_index] |
| |
| if isinstance(t, np.ndarray): |
| return t.tolist() |
| |
| if isinstance(t, ops.Tensor): |
| return backend.get_value(t).tolist() |
| |
| return t |