| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Define tflite op hints (intrinsic operations). |
| |
| This essentially allows defining a TensorFlow API for tflite operations in |
| Python with hints on how they are represented in TensorFlow Lite. In effect, |
| it is a form of tflite intrinsic: it wraps a subpart of a TensorFlow execution |
| graph and is useful for LSTMs and other complicated TensorFlow constructions |
| that are difficult to pattern match in TOCO, but are represented by a single |
| accelerated tflite op. |
| |
| Example: |
| def tflite_cool_activation(input): |
| # A cool activation function. |
| custom = tf.lite.OpHint("cool_activation") |
| input, = custom.add_inputs(input) |
| output = tf.sigmoid(input) * input |
| output, = custom.add_outputs(output) |
| return output |
| |
| image = tf.compat.v1.placeholder(tf.float32, (1, 16, 16, 1)) |
| output = tf.identity(tflite_cool_activation(image)) |
| |
| session = tf.compat.v1.Session() |
| |
| graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session) |
| tflite_graph = tf.compat.v1.lite.toco_convert( |
| graphdef_to_convert, [image], [output], allow_custom_ops=True) |
| with open("/tmp/graph.fb", "wb") as fp: |
| fp.write(tflite_graph) |
| |
| How does it work?: |
| |
| OpHint is a helper that you use when defining a vanilla python function. |
| It allows you to wrap arguments in tf.identity ops that carry custom attributes. |
| These attributes allow you to find the original block of ops that was created. |
| For example, if you use cool_activation above you essentially get: |
| |
| a_input = tf.identity(input)   # with OpHint attributes attached |
| result = tf.multiply(tf.sigmoid(a_input), a_input) |
| output = tf.identity(result)   # with OpHint attributes attached |
| |
| a_input and output are identities that carry attributes recording which |
| argument they are, the name of the function they should turn into in |
| TF Lite, and a GUID that uniquely identifies a particular invocation. |
| |
| Once you have built your whole TensorFlow graph, you can run and train it |
| as usual, but after you have done that, you need to convert the graph into |
| a form that replaces these identity-wrapped subgraphs with stub ops. These |
| ops don't actually exist in the normal TensorFlow runtime, but will be |
| understood by toco later. The generated TensorFlow Lite flatbuffer file will |
| contain a custom operator called "cool_activation". The developer needs to |
| implement and register this operator in TensorFlow Lite in order to run |
| inference. |
| """ |
| |
| # TODO(aselle): Make this use generic graph transformations. |
| # TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name. |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections as _collections |
| import copy as _copy |
| import json as _json |
| import uuid as _uuid |
| import six as _six |
| |
| from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2 |
| from tensorflow.core.framework import graph_pb2 as _graph_pb2 |
| from tensorflow.core.framework import node_def_pb2 as _node_def_pb2 |
| from tensorflow.python.framework import dtypes as _dtypes |
| from tensorflow.python.framework import ops as _ops |
| from tensorflow.python.framework import tensor_util as _tensor_util |
| # TODO(aselle): publicize these apis if we continue to use these. |
| from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes |
| from tensorflow.python.framework.graph_util_impl import _extract_graph_summary |
| from tensorflow.python.ops import array_ops as _array_ops |
| from tensorflow.python.util import compat as _compat |
| from tensorflow.python.util.all_util import remove_undocumented |
| from tensorflow.python.util.tf_export import tf_export as _tf_export |
| |
| |
| @_tf_export(v1=["lite.OpHint"]) |
| class OpHint(object): |
| """A class that helps build tflite function invocations. |
| |
| It allows you to take a bunch of TensorFlow ops and annotate the construction |
| such that toco knows how to convert it to tflite. This embeds a pseudo |
| function in a TensorFlow graph. This allows embedding high-level API usage |
| information in a lower level TensorFlow implementation so that an alternative |
| implementation can be substituted later. |
| |
| Essentially, any "input" into this pseudo op is fed into an identity, and |
| attributes are added to that input before being used by the constituent ops |
| that make up the pseudo op. A similar process is done to any output that |
| is to be exported from the current op. |
| |
| """ |
| # TODO(aselle): When TensorFlow functions functionality works for arbitrary |
| # constructs, this mechanism can be retired and changed to use python defun's. |
| |
| # Attr constants that are used for representation in the GraphDef. These |
| # will be used on every Identity op that is involved in a total OpHint. |
| |
| # Name of the OpHint function (cosmetic). |
| FUNCTION_NAME_ATTR = "_tflite_function_name" |
| # UUID of the function (each OpHint gets a new uuid). |
| FUNCTION_UUID_ATTR = "_tflite_function_uuid" |
| # The input index of the input (or nothing if it is an output). |
| FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index" |
| # The output index of the output (or nothing if it is an input). |
| FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index" |
| # An index that orders aggregate arguments. Aggregate arguments are ones |
| # that are separate but will be fused horizontally. For example a static LSTM |
| # has a lstm cell for each time step. Each one has a separate opHint, but a |
| # fused SequentialLSTM will treat this as a single tensor. |
| FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index" |
| # The way in which multiple parts of the aggregate argument will be joined |
| # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST, |
| # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK. |
| FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate" |
| # On fused OpHint stub, the order of inputs that the final LSTM call will |
| # have. What this means is that the TensorFlow order might be |
| # "foo", "bar", "stuff" and you might want the TF lite op order to be |
| # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this |
| # attribute to [2, 0, 1, -1]. |
| TFLITE_INPUT_INDICES = "_tflite_input_indices" |
| # OpHint level. |
| FUNCTION_LEVEL_ATTR = "_tflite_ophint_level" |
| # Ophint internal mapping; this is for high-level Ophints only. |
| # This basically contains three kinds of mapping: |
| # 1) How parental ophinted inputs map to the first child ophinted inputs; |
| # 2) How internal children nodes are connected; |
| # 3) How parental ophinted outputs map to the last child ophinted outputs. |
| CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping" |
| |
| # Types of aggregations |
| # stack: stacks all ophints with matching tags, e.g. for a static rnn. |
| # Specifically, this is good for an input or output to a static rnn cell. |
| AGGREGATE_STACK = "stack" |
| # first: only takes the first output (one with lowest sort index) |
| # of matching tags. This is good for the input state to an RNN. |
| AGGREGATE_FIRST = "first" |
| # last: takes only the last output (one with highest sort index) of matching tags. |
| # This is good for an output value on the last stack item of a |
| # static rnn. |
| AGGREGATE_LAST = "last" |
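| |
| # Illustrative sketch (not executed): how tags and aggregation are typically |
| # combined. For a static RNN, each timestep adds its input to the same tag |
| # with AGGREGATE_STACK so the fused op sees one stacked tensor, while the |
| # final state uses AGGREGATE_LAST. The names below are hypothetical. |
| # |
| #   hint = OpHint("sketch_rnn") |
| #   for t in range(num_steps): |
| #     x_t = hint.add_input(inputs[t], tag="x", |
| #                          aggregate=OpHint.AGGREGATE_STACK) |
| #     ... build the cell for timestep t from x_t ... |
| #   state_out = hint.add_output(final_state, tag="state", |
| #                               aggregate=OpHint.AGGREGATE_LAST) |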
| |
| class OpHintArgumentTracker(object): |
| """Conceptually tracks indices of arguments of "OpHint functions". |
| |
| The inputs and outputs of these functions each use an instance |
| of this class so they can have independent numbering. |
| """ |
| |
| def __init__(self, |
| function_name, |
| unique_function_id, |
| node_name_prefix, |
| attr_name, |
| level=1, |
| children_inputs_mappings=None): |
| """Initialize ophint argument. |
| |
| Args: |
| function_name: Name of the function that this tracks arguments for. |
| unique_function_id: UUID of function that this tracks arguments for. |
| node_name_prefix: How identities that are created are named. |
| attr_name: Name of attribute to use to store the index for this hint. |
| i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX |
| level: Hierarchical level of the Ophint node, a number. |
| children_inputs_mappings: Inputs/Outputs mapping for children hints. |
| """ |
| |
| # The global index is the argument index of the op. This is in contrast |
| # to the sort index which is the sequence number of a particular instance |
| # of a given global index. For example, you may have called add hint |
| # twice with the tag "foo". Then the global index will be 0 for both |
| # and the sort index will be 0 for the first added and 1 for the second. |
| self._function_name = function_name |
| self._unique_function_id = unique_function_id |
| self._next_global_index = 0 # The absolute global index |
| self._used_global_indices = set() |
| self._tag_to_global_index = {} # The argument index a given tag maps to |
| self._tag_to_next_sort_index = {} # The current index for each tag |
| self._node_name_prefix = node_name_prefix |
| self._attr_name = attr_name |
| self._level = level |
| self._children_inputs_mappings = children_inputs_mappings |
| |
| def _get_new_global_index(self, index_override): |
| """Return the next unused argument index in order or use an override. |
| |
| Args: |
| index_override: An index to use instead of the next available or None |
| to use the next available. |
| |
| Returns: |
| A valid global_index to use for the next hint argument. |
| |
| Raises: |
| ValueError: If the index_override is already used by another hint. |
| """ |
| if index_override is None: |
| global_index = self._next_global_index |
| else: |
| if index_override in self._used_global_indices: |
| raise ValueError("Index %d was already used by another call to add" |
| % index_override) |
| global_index = index_override |
| # Make next_global_index valid |
| self._used_global_indices.add(global_index) |
| while self._next_global_index in self._used_global_indices: |
| self._next_global_index += 1 |
| return global_index |
| |
| def add(self, arg, tag=None, name=None, aggregate=None, |
| index_override=None): |
| """Return a wrapped tensor of an input tensor as an argument. |
| |
| Args: |
| arg: A TensorFlow tensor that should be considered an argument. |
| tag: String tag to identify arguments that should be packed. |
| name: Name of argument. This is included in the Identity hint op names. |
| aggregate: Strategy to aggregate. |
| Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, |
| and OpHint.AGGREGATE_STACK. |
| Note, aggregate is only valid if tag is specified. |
| index_override: Specify what input/output index this should be in the |
| final stub. i.e. add(arg0, index_override=1); add(arg1, index_override=0) |
| will make the final stub be stub_func(inputs=[arg1, arg0], outputs=[]) |
| rather than the default call-order-based ordering. |
| |
| Returns: |
| A tensor representing the wrapped argument. |
| |
| Raises: |
| ValueError: When indices are not consistent. |
| """ |
| |
| # Find the appropriate index |
| if tag is None: |
| if aggregate is not None: |
| raise ValueError("You must specify `tag` if using aggregate.") |
| global_index = self._get_new_global_index(index_override) |
| sort_index = None |
| else: |
| if aggregate is None: |
| raise ValueError("You must specify `aggregate` if using tag.") |
| if tag not in self._tag_to_global_index: |
| self._tag_to_global_index[tag] = ( |
| self._get_new_global_index(index_override)) |
| self._tag_to_next_sort_index[tag] = 0 |
| elif (index_override and |
| index_override != self._tag_to_global_index[tag]): |
| raise ValueError( |
| "Tag %r was called with two indices %r and %r" % |
| (tag, index_override, self._tag_to_global_index[tag])) |
| global_index = self._tag_to_global_index[tag] |
| sort_index = self._tag_to_next_sort_index[tag] |
| self._tag_to_next_sort_index[tag] += 1 |
| |
| uuid = self._unique_function_id |
| name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name, |
| uuid, global_index, sort_index, name) |
| |
| identity_op = _array_ops.identity(arg, name=name) |
| |
| # pylint: disable=protected-access |
| identity_op.op._set_attr( |
| OpHint.FUNCTION_NAME_ATTR, |
| _attr_value_pb2.AttrValue( |
| s=_compat.as_bytes(self._function_name))) |
| identity_op.op._set_attr( |
| OpHint.FUNCTION_UUID_ATTR, |
| _attr_value_pb2.AttrValue( |
| s=_compat.as_bytes(self._unique_function_id))) |
| identity_op.op._set_attr( |
| self._attr_name, _attr_value_pb2.AttrValue(i=global_index)) |
| identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR, |
| _attr_value_pb2.AttrValue(i=self._level)) |
| if self._children_inputs_mappings: |
| identity_op.op._set_attr( |
| OpHint.CHILDREN_INPUTS_MAPPINGS, |
| _attr_value_pb2.AttrValue( |
| s=_compat.as_bytes(_json.dumps( |
| self._children_inputs_mappings)))) |
| |
| if sort_index is not None: |
| identity_op.op._set_attr( |
| OpHint.FUNCTION_SORT_INDEX_ATTR, |
| _attr_value_pb2.AttrValue(i=sort_index)) |
| if aggregate is not None: |
| identity_op.op._set_attr( |
| OpHint.FUNCTION_AGGREGATE_ATTR, |
| _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate)))) |
| # pylint: enable=protected-access |
| return identity_op |
| |
| def __init__(self, |
| function_name, |
| level=1, |
| children_inputs_mappings=None, |
| **kwargs): |
| """Create a OpHint. |
| |
| Args: |
| function_name: Name of the function (the custom op name in tflite) |
| level: OpHint level. |
| children_inputs_mappings: Children OpHint inputs/outputs mapping. |
| children_inputs_mappings should look like below: |
| "parent_first_child_input": |
| [{"parent_ophint_input_index": num, |
| "first_child_ophint_input_index": num}, ...] |
| "parent_last_child_output": |
| [{"parent_output_index": num, "child_output_index": num}, ...] |
| "internal_children_input_output": |
| [{"child_input_index": num, "child_output_index": num}, ...] |
| **kwargs: Keyword arguments of any constant attributes for the function. |
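| |
| Example (illustrative; the index values below are placeholders): |
| children_inputs_mappings = { |
| "parent_first_child_input": [ |
| {"parent_ophint_input_index": 0, |
| "first_child_ophint_input_index": 0}], |
| "parent_last_child_output": [ |
| {"parent_output_index": 0, "child_output_index": 0}], |
| "internal_children_input_output": [ |
| {"child_input_index": 0, "child_output_index": 0}], |
| } |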
| """ |
| self._function_name = function_name |
| self._level = level |
| if self._level == 1: |
| assert children_inputs_mappings is None |
| else: |
| assert isinstance(children_inputs_mappings, dict) |
| self._children_inputs_mappings = children_inputs_mappings |
| if self._children_inputs_mappings is not None: |
| self._validate_children_inputs_mappings(self._children_inputs_mappings) |
| self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough? |
| self._attrs_to_store_later = kwargs |
| self._stored_attrs = False |
| self._inputs = OpHint.OpHintArgumentTracker( |
| self._function_name, self._unique_function_id, "InputHint", |
| OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings) |
| self._outputs = OpHint.OpHintArgumentTracker( |
| self._function_name, self._unique_function_id, "OutputHint", |
| OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level, |
| self._children_inputs_mappings) |
| |
| def _validate_children_inputs_mappings(self, children_inputs_mappings): |
| """Validate children inputs mappings is in the right format. |
| |
| Args: |
| children_inputs_mappings: the Children ophint inputs/outputs mapping. |
| """ |
| assert isinstance(children_inputs_mappings, dict) |
| assert "parent_first_child_input" in children_inputs_mappings |
| assert "parent_last_child_output" in children_inputs_mappings |
| assert "internal_children_input_output" in children_inputs_mappings |
| |
| # validate parent_first_child_input. |
| |
| def assert_dictlist_has_keys(dictlist, keys): |
| for dikt in dictlist: |
| assert isinstance(dikt, dict) |
| for key in keys: |
| assert key in dikt |
| |
| assert_dictlist_has_keys( |
| children_inputs_mappings["parent_first_child_input"], |
| ["parent_ophint_input_index", "first_child_ophint_input_index"]) |
| assert_dictlist_has_keys( |
| children_inputs_mappings["parent_last_child_output"], |
| ["parent_output_index", "child_output_index"]) |
| assert_dictlist_has_keys( |
| children_inputs_mappings["internal_children_input_output"], |
| ["child_input_index", "child_output_index"]) |
| |
| def _setattr(self, dest_op, name, value): |
| tensor_value = _ops.convert_to_tensor(value) |
| # pylint: disable=protected-access |
| dest_op.op._set_attr(name, _attr_value_pb2.AttrValue( |
| tensor=tensor_value.op.node_def.attr["value"].tensor)) |
| # pylint: enable=protected-access |
| |
| def add_input(self, *args, **kwargs): |
| """Add a wrapped input argument to the hint. |
| |
| Args: |
| *args: The input tensor. |
| **kwargs: |
| "name" label |
| "tag" a tag to group multiple arguments that will be aggregated. I.e. |
| a string like 'cool_input'. Basically multiple inputs can be added |
| to the same hint for parallel operations that will eventually be |
| combined. An example would be static_rnn which creates multiple copies |
| of state or inputs. |
| "aggregate" aggregation strategy that is valid only for tag non None. |
| Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, |
| and OpHint.AGGREGATE_STACK. |
| "index_override" The global index to use. This corresponds to the |
| argument order in the final stub that will be generated. |
| Returns: |
| The wrapped input tensor. |
| """ |
| return self._inputs.add(*args, **kwargs) |
| |
| def add_output(self, *args, **kwargs): |
| """Add a wrapped output argument to the hint. |
| |
| Args: |
| *args: The output tensor. |
| **kwargs: |
| "name" label |
| "tag" a tag to group multiple arguments that will be aggregated. I.e. |
| a string like 'cool_input'. Basically multiple inputs can be added |
| to the same hint for parallel operations that will eventually be |
| combined. An example would be static_rnn which creates multiple copies |
| of state or inputs. |
| "aggregate" aggregation strategy that is valid only for tag non None. |
| Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, |
| and OpHint.AGGREGATE_STACK. |
| "index_override" The global index to use. This corresponds to the |
| argument order in the final stub that will be generated. |
| Returns: |
| The wrapped output tensor. |
| """ |
| return self._outputs.add(*args, **kwargs) |
| |
| def add_inputs(self, *args, **kwargs): |
| """Add a sequence of inputs to the function invocation. |
| |
| Args: |
| *args: List of inputs to be converted (should be tf.Tensor). |
| **kwargs: This allows 'names' which should be a list of names. |
| |
| Returns: |
| Wrapped inputs (identity standins that have additional metadata). These |
| are also tf.Tensor's. |
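| |
| Example (illustrative; `hint` is an OpHint and `a`, `b` are tf.Tensors): |
| a, b = hint.add_inputs(a, b, names=["a", "b"]) |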
| """ |
| if "names" in kwargs: |
| return [ |
| self._inputs.add(arg, name=name) |
| for arg, name in zip(args, kwargs["names"]) |
| ] |
| else: |
| return [self._inputs.add(arg) for arg in args] |
| |
| def add_outputs(self, *args, **kwargs): |
| """Add a sequence of outputs to the function invocation. |
| |
| Args: |
| *args: List of outputs to be converted (should be tf.Tensor). |
| **kwargs: This allows 'names' which should be a list of names. |
| |
| Returns: |
| Wrapped outputs (identity standins that have additional metadata). These |
| are also tf.Tensor's. |
| """ |
| if "names" in kwargs: |
| return [ |
| self._outputs.add(arg, name=name) |
| for arg, name in zip(args, kwargs["names"]) |
| ] |
| else: |
| return [self._outputs.add(arg) for arg in args] |
| |
| |
| class _LiteOperand(object): |
| """Abstract operand for a tflite hint function._dynamic_rnn_loop. |
| |
| This is a base class that handles representing arguments to an OpHint. |
| It also is able to serialize operands to the stubbed graph_def. |
| Child classes are responsible for being able to |
| store information about the hint identity operators. They are also responsible |
| for knowing how to serialize to output graphdefs. |
| |
| Typically this will be implemented by holding one or more identity nodes |
| that were previously discovered as hints. |
| """ |
| |
| def aggregate_and_return_name_for_input(self, out_graphdef): |
| """This adds the node(s) to out_graphdef and returns the input node name. |
| |
| Args: |
| out_graphdef: A graphdef that is ready to have this input added. |
| |
| Returns: |
| The output that the stub should use as an input for this operand. |
| |
| Raises: |
| RuntimeError: if the method is not implemented. |
| """ |
| del out_graphdef |
| raise RuntimeError("Unimplemented abstract method.") |
| |
| def aggregate_and_return_name_for_output(self, fused_op_name, output_index, |
| out_graphdef): |
| """Add node(s) to graph representing output operands and returns type. |
| |
| Args: |
| fused_op_name: name of the fused op stub name. |
| output_index: Output index that we are currently processing from stub. |
| out_graphdef: The destination graphdef we are currently building up. |
| |
| Returns: |
| The datatype of this identity. |
| |
| Raises: |
| RuntimeError: if the method is not implemented. |
| """ |
| del fused_op_name, output_index, out_graphdef |
| raise RuntimeError("Unimplemented abstract method.") |
| |
| |
| class _LiteSingleOperand(_LiteOperand): |
| """A simple operand that is non-aggregated (i.e. most hints).""" |
| |
| def __init__(self, node): |
| _LiteOperand.__init__(self) |
| self.node = node |
| self.name = _tensor_name_base(node.name) |
| |
| def flatten(self): |
| return [self.name] |
| |
| def aggregate_and_return_name_for_input(self, out_graphdef): |
| return self.name |
| |
| def aggregate_and_return_name_for_output(self, fused_op_name, index, |
| out_graphdef): |
| output_node = _copy.deepcopy(self.node) |
| del output_node.input[:] |
| output_node.input.append(_tensorflow_output_name(fused_op_name, index)) |
| out_graphdef.node.extend([output_node]) |
| return self.node.attr["type"].i |
| |
| def __str__(self): |
| return str(self.name) |
| |
| |
| class _LiteAggregateOperand(_LiteOperand): |
| """An operand for a tflite hint function that is aggregated from many. |
| |
| For example, an LSTM is a grid of operators that are all related. Inputs |
| going into them may need to be fused, so they should all be tracked as |
| related arguments. |
| """ |
| |
| def __init__(self, aggregation): |
| _LiteOperand.__init__(self) |
| self.aggregation = aggregation |
| self.names = {} |
| self.nodes = {} |
| self.flattened = None |
| |
| def add(self, sort, node): |
| self.names[sort] = _tensor_name_base(node.name) |
| self.nodes[sort] = node |
| |
| def flatten_nodes(self): |
| """Return a list of all the node protos in aggregation sorted order.""" |
| if not self.flattened: |
| self.flattened = [None] * len(self.nodes) |
| for idx, node in _six.iteritems(self.nodes): |
| self.flattened[idx] = node |
| for n in self.flattened: |
| if n is None: |
| raise RuntimeError("Aggregate was missing argument.") |
| if self.aggregation == OpHint.AGGREGATE_FIRST: |
| self.flattened = self.flattened[:1] |
| elif self.aggregation == OpHint.AGGREGATE_LAST: |
| self.flattened = self.flattened[-1:] |
| elif self.aggregation == OpHint.AGGREGATE_STACK: |
| pass |
| else: |
| raise ValueError("Invalid aggregation type %r specified" % |
| self.aggregation) |
| return self.flattened |
| |
| def flatten(self): |
| """Return a list of all node names in aggregation sorted sorter.""" |
| return [_tensor_name_base(x.name) for x in self.flatten_nodes()] |
| |
| def aggregate_and_return_name_for_input(self, out_graphdef): |
| """This adds the nodes to out_graphdef and returns an aggregated output. |
| |
| In particular, if you have 4 inputs to a hint stub, this will be the |
| node that you can use as an output. E.g. if you have 4 timesteps from a |
| static rnn, then a fused UnidirectionalLSTM will expect 1 input with |
| all 4 time steps. So here we make a pack and return the output name of |
| that pack. |
| |
| Args: |
| out_graphdef: A graphdef that is ready to have this input added. |
| |
| Returns: |
| The name of a pack that aggregates this node. |
| """ |
| flattened = self.flatten_nodes() |
| if (self.aggregation == OpHint.AGGREGATE_FIRST) or ( |
| self.aggregation == OpHint.AGGREGATE_LAST): |
| assert len(flattened) == 1 |
| if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK: |
| return _tensor_name_base(flattened[0].name) |
| else: |
| new_node = _node_def_pb2.NodeDef() |
| new_node.op = "Pack" |
| new_node.name = "OpHintStack-%s" % flattened[0].name |
| new_node.attr["N"].i = len(flattened) |
| new_node.attr["T"].type = flattened[0].attr["T"].type |
| for discrete in flattened: |
| new_node.input.append(_tensor_name_base(discrete.name)) |
| out_graphdef.node.extend([new_node]) |
| return new_node.name |
| |
| def aggregate_and_return_name_for_output(self, fused_op_name, output_index, |
| out_graphdef): |
| """This adds to `out_graphdef` all the unaggregated outputs. |
| |
| I.e. we are outputting from a fused stub, but we need to make it compatible |
| with the unfused original graph so we insert an unpack. Ideally in a later |
| stage the unpack -> pack sequences will be removed. |
| |
| Args: |
| fused_op_name: The name of the stub we are in the process of fusing. |
| output_index: The output index this object represents. |
| out_graphdef: The graphdef we are in the process of building. |
| |
| Returns: |
| The type of the aggregated output (so we can finish building the stub |
| op). |
| """ |
| flattened = self.flatten_nodes() |
| if (self.aggregation == OpHint.AGGREGATE_FIRST) or ( |
| self.aggregation == OpHint.AGGREGATE_LAST): |
| assert len(flattened) == 1 |
| if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK: |
| temp_op = _LiteSingleOperand(flattened[0]) |
| return temp_op.aggregate_and_return_name_for_output( |
| fused_op_name, output_index, out_graphdef) |
| else: |
| stack_node = _node_def_pb2.NodeDef() |
| stack_node.op = "Unpack" |
| stack_node.name = "OpHintUnstack-%s" % flattened[0].name |
| stack_node.attr["num"].i = len(flattened) |
| output_type = flattened[0].attr["T"].type |
| stack_node.attr["T"].type = output_type |
| stack_node.input.append( |
| _tensorflow_output_name(fused_op_name, output_index)) |
| out_graphdef.node.extend([stack_node]) |
| |
| for idx, discrete in enumerate(flattened): |
| output_node = _copy.deepcopy(discrete) |
| del output_node.input[:] |
| output_node.input.append(_tensorflow_output_name(stack_node.name, idx)) |
| out_graphdef.node.extend([output_node]) |
| |
| return output_type |
| |
| def __str__(self): |
| s = "\t\t\tAGGREGATE %s\n" % self.aggregation |
| for sort, val in _six.iteritems(self.names): |
| s += "\t\t\t%d: %s\n" % (sort, val) |
| return s |
| |
| |
| class _LiteFuncCall(object): |
| """Represent a TensorFlow Lite custom function. |
| |
| This is used to accumulate found hints in the graphdef into a single |
| conceptual unit. |
| |
| Attributes: |
| inputs: inputs to the op (hash from index # to argument) |
| outputs: outputs to the op (hash from index # to argument) |
| function_name: the tflite custom op name to use |
| uuid: a unique call id for this particular call (i.e. multiple function |
| calls would have the same function_name but different uuids). |
| params: A param name to key value for op constant data. I.e. for axis on a |
| reduction, strides on a convolution, etc. |
| level: Level of the OpHint. |
| children_inputs_mappings: If the Ophint has children, children inputs |
| mappings indicate how their inputs & outputs are mapped. |
| """ |
| |
| def __init__(self): |
| self.inputs = {} |
| self.outputs = {} |
| self.function_name = None |
| self.uuid = None |
| self.params = {} |
| self.level = -1 |
| self.children_inputs_mappings = {} |
| |
| def flattened_inputs_and_outputs(self): |
| """Return a list of inputs and outputs in a flattened format. |
| |
| Returns: |
| Tuple of (inputs, outputs), where inputs and outputs are lists of names. |
| """ |
| |
| def _flatten(input_or_output_dict): |
| flattened_items = [] |
| for item in input_or_output_dict.values(): |
| flattened_items.extend(item.flatten()) |
| return flattened_items |
| |
| return _flatten(self.inputs), _flatten(self.outputs) |
| |
| def __str__(self): |
| |
| def format_args(items): |
| s = "" |
| for idx, item in _six.iteritems(items): |
| s += ("\t\t%d:\n" % idx) + str(item) |
| return s |
| |
| inputs_str = "\tInputs\n" + format_args(self.inputs) |
| outputs_str = "\tOutputs\n" + format_args(self.outputs) |
| |
| return ( |
| "tflite function %s call %s level %d " |
| "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" % |
| (self.function_name, self.uuid, self.level, inputs_str, outputs_str)) |
| |
| |
| def _find_all_hints_in_nodes(nodes): |
| """Look at the all the input nodes and return a list of LiteFuncCall objs. |
| |
| Args: |
| nodes: A TensorFlow graph_def to look for LiteFuncCalls. |
| |
| Returns: |
| a list of `LifeFuncCall` objects in the form |
| |
| """ |
| func_calls = _collections.defaultdict(_LiteFuncCall) |
| |
| for node in nodes: |
| attr = node.attr |
| # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip |
| if (OpHint.FUNCTION_UUID_ATTR not in attr or |
| not attr[OpHint.FUNCTION_UUID_ATTR].s): |
| continue |
| uuid = attr[OpHint.FUNCTION_UUID_ATTR].s |
| |
| # Start building function |
| call_def = func_calls[uuid] |
| call_def.uuid = uuid |
| call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s |
| call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i |
| # Get sorting and aggregation information |
| |
| sort = ( |
| attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i |
| if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None) |
| if sort == -1: |
| sort = None |
| aggregation = None |
| if OpHint.FUNCTION_AGGREGATE_ATTR in attr: |
| aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s) |
| |
| if OpHint.CHILDREN_INPUTS_MAPPINGS in attr: |
| call_def.children_inputs_mappings = _json.loads( |
| _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s)) |
| |
| # Add the input or output |
| def put_operand(stuff, index, sort, operand, aggregation): |
| """Add a given index into the function structure.""" |
| if sort is None: |
| stuff[index] = _LiteSingleOperand(operand) |
| else: |
| if index not in stuff: |
| stuff[index] = _LiteAggregateOperand(aggregation) |
| stuff[index].add(sort, operand) |
| |
| if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: |
| put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i, |
| sort, node, aggregation) |
| if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr: |
| put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i, |
| sort, node, aggregation) |
| |
| # Remember attributes |
| for a in attr: |
| if a.startswith("_tflite_attr_"): |
| call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor |
| |
| return func_calls |
| |
| |
| def _extract_topology_sequence_mapping(nodes): |
| return dict( |
| (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes)) |
| |
| |
| def _find_children_hints_in_while_loop(function_def, nodes_mapping): |
| """Find children hints and all nodes inside the while loop. |
| |
| Args: |
| function_def: Function def of the while loop. |
| nodes_mapping: Dict mapping while-loop input_arg names to real node names. |
| |
| Returns: |
| Ordered children hints and all re-mapped nodes inside the while loop. |
| """ |
| new_nodes = [] |
| |
| # Make the inputs of nodes inside the function def point to the real nodes. |
| for node in function_def.node_def: |
| for i, _ in enumerate(node.input): |
| if node.input[i] in nodes_mapping: |
| node.input[i] = nodes_mapping[node.input[i]] |
| new_nodes.append(_copy.deepcopy(node)) |
| name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def) |
| children_hints = _find_all_hints_in_nodes(new_nodes) |
| children_hints_q = [] |
| # Ordered by the outputs. |
| for hint in _six.itervalues(children_hints): |
| _, output_names = hint.flattened_inputs_and_outputs() |
| seq = name_to_seq_num[output_names[0]] |
| for output_name in output_names: |
| seq = min(seq, name_to_seq_num[output_name]) |
| children_hints_q.append((seq, hint)) |
| children_hints_q.sort(key=lambda tup: tup[0]) |
| ordered_children_hints = [x[1] for x in children_hints_q] |
| return ordered_children_hints, new_nodes |
| |
| |
| def _find_children_hints(call, graph_def): |
| """Find all children hints. |
| |
| For a given OpHint, we find all children hints inside it. We also copy all the |
| nodes inside function defs (if applicable) to the original graph_def; they are |
| returned in a list as well. |
| |
| Args: |
| call: Parent OpHint that contains children ophints. |
| graph_def: Original graph def. |
| |
| Returns: |
| Ordered children hints inside the parent ophint; new graph def that contains |
| nodes inside function defs (if applicable); nodes inside function defs. |
| """ |
| name_to_input_name, _, _ = _extract_graph_summary(graph_def) |
| input_names, output_names = call.flattened_inputs_and_outputs() |
| |
| reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) |
| reachable_by_output = _bfs_for_reachable_nodes(output_names, |
| name_to_input_name) |
| output_nodes_set = set(output_names) |
| children_hints = [] |
| out = _graph_pb2.GraphDef() |
| out.library.CopyFrom(graph_def.library) |
| out.versions.CopyFrom(graph_def.versions) |
| function_def_nodes = set() |
| for node in graph_def.node: |
| out.node.extend([_copy.deepcopy(node)]) |
| n = _tensor_name_base(node.name) |
| if n in reachable_by_output: |
| if n not in reachable_by_input and n not in output_nodes_set: |
| # Special handling for while loop function defs. |
| if node.op == "While" or node.op == "StatelessWhile": |
| body_name = node.attr["body"].func.name |
| inputs_outside_loop = node.input |
| for function_def in graph_def.library.function: |
| if function_def.signature.name == body_name: |
| function_inputs = function_def.signature.input_arg |
| assert len(inputs_outside_loop) == len(function_inputs) |
| nodes_mapping = {} |
| for i, function_input in enumerate(function_inputs): |
| nodes_mapping[function_input.name] = inputs_outside_loop[i] |
| # TODO(b/123050804): Consider use grappler. |
| (children_hints_in_loop, |
| new_nodes) = _find_children_hints_in_while_loop( |
| function_def, nodes_mapping) |
| function_def_nodes.update([x.name for x in new_nodes]) |
| children_hints.extend(children_hints_in_loop) |
| out.node.extend(new_nodes) |
| |
| return children_hints, out, function_def_nodes |
| |
| |
| def _tensor_name_base(full_tensor_name): |
| """Removes the device assignment code from a tensor. |
| |
| e.g. _tensor_name_base("foo:3") => "foo" |
| |
| Args: |
| full_tensor_name: A tensor name that is annotated with a device placement |
| (this is what tensor flow introspection gives). |
| |
| Returns: |
| A name without any device assignment. |
| """ |
| if full_tensor_name.startswith("^"): |
| return full_tensor_name[1:] |
| return full_tensor_name.split(":")[0] |
| |
| |
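| # Maps a node name and an output index to the string form used in NodeDef |
| # input lists, e.g. ("foo", 0) -> "foo" and ("foo", 2) -> "foo:2". |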
| def _tensorflow_output_name(tensor_name, output_index): |
| return tensor_name if output_index == 0 else "%s:%d" % (tensor_name, |
| output_index) |
| |
| |
| # TODO(aselle): This should be converted to grappler in the future. |
| def _check_subgraph_closed(n, reachable_by_input, input_nodes_set, |
| name_to_input_name): |
| """Checks to make sure node only connects to predecessor graph through inputs. |
| |
| Args: |
| n: Node to check |
| reachable_by_input: Nodes that are reachable by all inputs of subgraph |
| input_nodes_set: The set of nodes that are "inputs". |
| name_to_input_name: Maps from name to the list of inputs. |
| |
| Raises: |
| TypeError: If the given node uses items past inputs directly. |
| """ |
| next_to_visit = [n] |
| visited = set() |
| while next_to_visit: |
| current_node = next_to_visit.pop() |
| visited.add(current_node) |
| if (current_node in reachable_by_input and |
| current_node not in input_nodes_set): |
| raise TypeError("Node %s uses input %s not in input_nodes." % |
| (n, current_node)) |
| if current_node not in input_nodes_set: |
| next_to_visit += [ |
| input_node for input_node in name_to_input_name[current_node] |
| if input_node not in visited |
| ] |
| |
| |
| # TODO(aselle): This should be converted to grappler in the future. |
| def _convert_single_op_hint_to_stub(call, |
| graph_def, |
| function_def_nodes=None, |
| is_last_run=True): |
| """Given a graph_def, converts `call` into a stub and returns a new graph_def. |
| |
| Args: |
| call: A single function call to be converted. |
| graph_def: A graph_def to use as input (that has call obviously). |
| function_def_nodes: Nodes inside the function def that are not connected to |
| the graph. |
| is_last_run: Whether it is the last run for a given pass (for OpHints that |
| have children). |
| |
| Returns: |
| A new transformed graph-def that has call as a stub (single op). |
| |
| Note: after this process, the graph_def can no longer be loaded into |
| the TensorFlow runtime, so all future manipulations are done at the |
| graph_def level. |
| """ |
| if function_def_nodes is None: |
| function_def_nodes = set() |
| name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( |
| graph_def) |
| input_names, output_names = call.flattened_inputs_and_outputs() |
| |
| reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) |
| reachable_by_output = _bfs_for_reachable_nodes(output_names, |
| name_to_input_name) |
| output_nodes_set = set(output_names) |
| nodes_after_fuse = [] |
| nodes_deleted_by_fuse = set() |
| # Classify each node. We want to keep everything reachable by input. Nodes |
| # reachable only from the outputs (and not outputs themselves) are consumed |
| # by the fuse; the remaining nodes come after the fused op. |
| for node in graph_def.node: |
| n = _tensor_name_base(node.name) |
| if n in reachable_by_output: |
| if n not in reachable_by_input and n not in output_nodes_set: |
| nodes_deleted_by_fuse.add(n) |
| elif n not in reachable_by_input and n not in function_def_nodes: |
| # n is a node that comes after all the fusings, so keep it. |
| nodes_after_fuse.append(n) |
| else: |
| # In the last run, n is a node that sits in the graph but is not |
| # connected to the chain of dependencies, so we delete it; in earlier |
| # runs we keep it. |
| if not is_last_run: |
| nodes_after_fuse.append(n) |
| |
| # Make a new graphdef with all the pre-input and input nodes |
| out = _graph_pb2.GraphDef() |
| reachable_by_input_sorted = sorted( |
| list(reachable_by_input), key=lambda n: name_to_seq_num[n]) |
| for node in reachable_by_input_sorted: |
| out.node.extend([_copy.deepcopy(name_to_node[node])]) |
| |
| # Create any stacks to aggregate arguments into a single input |
| # i.e. for static_rnn's. |
| # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1 |
| sorted_input_indices = list(call.inputs.keys()) |
| sorted_input_indices.sort() |
| sorted_output_indices = list(call.outputs.keys()) |
| sorted_output_indices.sort() |
| new_node = _node_def_pb2.NodeDef() |
| # Delegate to each operand to produce the proper new input for this stub node. |
| # In particular, an aggregate input will now be a Pack of some previously |
| # non-fused things. |
| |
| optional_input_node = _node_def_pb2.NodeDef() |
| optional_input_node.name = "Const" + str(_uuid.uuid1().hex) |
| optional_input_node.op = "Const" |
| optional_input_node.attr["dtype"].CopyFrom( |
| _attr_value_pb2.AttrValue(type=_dtypes.float32.as_datatype_enum)) |
| optional_input_node.attr["value"].CopyFrom( |
| _attr_value_pb2.AttrValue( |
| tensor=_tensor_util.make_tensor_proto([-1], _dtypes.float32, [1]))) |
| out.node.extend([optional_input_node]) |
| |
| max_index = max(sorted_input_indices) + 1 |
| for cur_index in range(max_index): |
| if cur_index in sorted_input_indices: |
| inputs = call.inputs[cur_index] |
| input_name = inputs.aggregate_and_return_name_for_input(out) |
| new_node.input.append(input_name) |
| else: |
| new_node.input.append(optional_input_node.name) |
| |
| new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices) |
| |
| # Create the function |
| new_node.op = call.function_name |
| new_node.name = call.uuid |
| out.node.extend([new_node]) |
| |
| # Now call each output argument to give them a chance to make the proper |
| # output type and add it to our new_node. |
| output_dtypes = [] |
| max_output_index = max(sorted_output_indices) + 1 |
| for cur_index in range(max_output_index): |
| if cur_index in sorted_output_indices: |
| output = call.outputs[cur_index] |
| output_dtype = ( |
| output.aggregate_and_return_name_for_output(new_node.name, cur_index, |
| out)) |
| else: |
| output_dtype = optional_input_node.attr["type"].i |
| output_dtypes.append(output_dtype) |
| new_node.attr["_output_types"].list.type[:] = output_dtypes |
| # TODO(aselle): what is right here? |
| new_node.attr["_output_quantized"].b = False |
| |
| # Add post output nodes that do not depend on the outputs |
| for n in nodes_after_fuse: |
| should_keep = True |
| for input_name in name_to_input_name[n]: |
| if input_name in nodes_deleted_by_fuse: |
| should_keep = False |
| if should_keep: |
| out.node.extend([_copy.deepcopy(name_to_node[n])]) |
| |
| # Misc. graph_def data that needs copying. |
| out.library.CopyFrom(graph_def.library) |
| out.versions.CopyFrom(graph_def.versions) |
| |
| return out |
| |
| |
| # TODO(aselle): This should be converted to grappler in the future. |
| def _remove_one_redundant_stack_unstack(in_graph_def): |
| """Removes a stack->unstack pattern from in_graph_def in a returned graph. |
| |
| Args: |
| in_graph_def: Graph def to use as input. |
| |
| Returns: |
| Simplified tuple (graph_def, changed_something) where changed_something |
| is true if anything was done. |
| """ |
| name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( |
| in_graph_def) |
| del name_to_seq_num |
| |
| # TODO(aselle): Make this not hardcoded. |
| do_generic_pack_unpack = True |
| |
| out = _graph_pb2.GraphDef() |
| out.library.CopyFrom(in_graph_def.library) |
| out.versions.CopyFrom(in_graph_def.versions) |
| for n in in_graph_def.node: |
| node_name = _tensor_name_base(n.name) |
| if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"): |
| continue |
| next_to_visit = [node_name] |
| visited = set() |
| |
| unpack_nodes = set() |
| pack_node = node_name |
| |
| # Find a pattern of unstack connected to a stack (with identities |
| # in between). |
| matches_pattern = True |
| is_hint_created_stack = False |
| while next_to_visit: |
| current_node_name = next_to_visit[0] |
| visited.add(current_node_name) |
| del next_to_visit[0] |
| node = name_to_node[current_node_name] |
| is_op_hint_stack = node.name.startswith("OpHintStack") |
| is_op_hint_unstack = node.name.startswith("OpHintUnstack") |
| if (node.op == "Identity" or is_op_hint_stack or |
| (do_generic_pack_unpack and node.op == "Pack")): |
| is_hint_created_stack |= is_op_hint_stack |
| next_to_visit += [ |
| input_node for input_node in name_to_input_name[current_node_name] |
| if input_node not in visited |
| ] |
| elif (is_op_hint_unstack or |
| (do_generic_pack_unpack and node.op == "Unpack")): |
| unpack_nodes.add(node.name) |
| is_hint_created_stack &= is_op_hint_unstack |
| else: |
| matches_pattern = False |
| break |
| visited.add(node.name) |
| |
| if matches_pattern and len(unpack_nodes) == 1: |
| pack_node = node_name |
| |
| # Check to see if anyone depends on the intermediate identity or the |
| # Unstacked form |
| no_external_dependency = True |
| for other_n in in_graph_def.node: |
| if other_n.name in visited: |
| continue |
| for input_tensor in name_to_input_name[other_n.name]: |
| input_op = _tensor_name_base(input_tensor) |
| if input_op in visited and input_op != pack_node: |
| no_external_dependency = False |
| # Proceed with the substitution if the stack/unstack pair was created |
| # through hints, or, if it was not, if nobody is consuming anything |
| # between the stack and unstack. |
| if is_hint_created_stack or no_external_dependency: |
| end = unpack_nodes.pop() |
| end_input = name_to_node[end].input[0] |
| # All nodes that depend on the final stack need to be rewired to use the |
| # stack's input instead. |
| for other_n in in_graph_def.node: |
| node_name = _tensor_name_base(other_n.name) |
| if node_name not in visited: |
| new_node = _copy.deepcopy(other_n) |
| new_node.input[:] = [ |
| (end_input if stripped == pack_node else non_stripped) |
| for stripped, non_stripped in zip(name_to_input_name[node_name], |
| new_node.input[:]) |
| ] |
| out.node.extend([new_node]) |
| return out, True |
| return in_graph_def, False |
| |
| |
| def _remove_redundant_stack_unstack(graph_def): |
| curr = graph_def |
| del graph_def |
| changed_stuff = True |
| while changed_stuff: |
| curr, changed_stuff = _remove_one_redundant_stack_unstack(curr) |
| return curr |
| |
| |
| def _get_correct_mapping(original_index, nodes): |
| # Special handling for the case where the index is -1. |
| # If it is -1, return the last index. |
| if original_index == -1: |
| node_indices = nodes.keys() |
| node_indices = sorted(node_indices) |
| return node_indices[-1] |
| return original_index |
| |
| |
| def _convert_op_hints_to_stubs_helper( |
| graph_def, write_callback=lambda sess, graph_def: None): |
| """Converts a graph_def to a new graph_def where all op hints are stubbed. |
| |
| Args: |
| graph_def: A graph def that we should convert. |
| write_callback: A function pointer that can be used to write intermediate |
| steps of graph transformation (optional). |
| |
| Returns: |
| A new stubbed graph_def. |
| """ |
| hints = _find_all_hints_in_nodes(graph_def.node) |
| |
| hints_q = [] |
| for hint in _six.itervalues(hints): |
| hints_q.append((hint.level, hint.uuid)) |
| |
| hints_q.sort(key=lambda tup: tup[0]) |
| |
| curr_graph_def = graph_def |
| del graph_def # prevent using graph_def again (common source of error) |
| for i in range(len(hints_q) - 1, -1, -1): |
| level, hint_uuid = hints_q[i] |
| if level >= 2: |
| children_hints, curr_graph_def, function_def_nodes = _find_children_hints( |
| hints[hint_uuid], curr_graph_def) |
| # pylint: disable=superfluous-parens |
| assert (len(children_hints) > 0) # pylint: disable=g-explicit-length-test |
| # pylint: enable=superfluous-parens |
| |
| # Re-wire the children hints inputs/outputs, so a later child's inputs |
| # connect to the previous child's outputs. |
| children_inputs_mappings = hints[hint_uuid].children_inputs_mappings |
| for j, child_hint in enumerate(children_hints): |
| if j == 0: |
| for mapping in children_inputs_mappings["parent_first_child_input"]: |
| parent_input_index = _get_correct_mapping( |
| mapping["parent_ophint_input_index"], hints[hint_uuid].inputs) |
| child_input_index = _get_correct_mapping( |
| mapping["first_child_ophint_input_index"], child_hint.inputs) |
| child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[ |
| parent_input_index] |
| else: |
| for mapping in children_inputs_mappings[ |
| "internal_children_input_output"]: |
| input_index = _get_correct_mapping(mapping["child_input_index"], |
| child_hint.inputs) |
| output_index = _get_correct_mapping(mapping["child_output_index"], |
| children_hints[j - 1].outputs) |
| child_hint.inputs[input_index] = children_hints[ |
| j - 1].outputs[output_index] |
| if j == len(children_hints) - 1: |
| for mapping in children_inputs_mappings["parent_last_child_output"]: |
| parent_output_index = _get_correct_mapping( |
| mapping["parent_output_index"], hints[hint_uuid].outputs) |
| child_output_index = _get_correct_mapping( |
| mapping["child_output_index"], child_hint.outputs) |
| child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[ |
| parent_output_index] |
| |
| for j, child_hint in enumerate(children_hints): |
| curr_graph_def = _convert_single_op_hint_to_stub( |
| child_hint, curr_graph_def, function_def_nodes, |
| j == len(children_hints) - 1) |
| else: |
| curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid], |
| curr_graph_def) |
| write_callback(curr_graph_def, "initial") |
| # The stubbing process can create stacks/unstacks in the case of LSTMs; |
| # remove them. |
| curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def) |
| return curr_graph_def |
| |
| |
| def find_all_hinted_output_nodes(session=None, graph_def=None): |
| """Find all Ophints output nodes in the graph. |
| |
| This is used to get all the ophinted output nodes. It is important for |
| operations like convert_variables_to_constants to keep the whole OpHint |
| structure. |
| Note: only one of session or graph_def should be used, not both. |
| Why can this be useful? Some TensorFlow ops (e.g. bidirectional rnn) can |
| generate multiple outputs for the unfused subgraph. If not all output nodes |
| are consumed, graph optimization can potentially drop the unused nodes and |
| leave the ophints in an invalid state (due to missing ophinted output nodes). |
| So it's important to find all those hinted output nodes and make sure they |
| are not discarded. |
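| |
| Example (a minimal sketch of a freeze-graph flow; `sess` and `output_names` |
| are placeholders for the caller's session and regular output node names): |
| hinted = find_all_hinted_output_nodes(session=sess) |
| frozen = tf.compat.v1.graph_util.convert_variables_to_constants( |
| sess, sess.graph_def, output_names + hinted) |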
| |
| Args: |
| session: A TensorFlow session that contains the graph to convert. |
| graph_def: A graph def that we should convert. |
| |
| Returns: |
| A list of OpHints output nodes. |
| Raises: |
| ValueError: If both session and graph_def are provided. |
| """ |
| if session is not None and graph_def is not None: |
| raise ValueError("Provide only one of session and graph_def.") |
| hinted_outputs_nodes = [] |
| if session is not None: |
| hints = _find_all_hints_in_nodes(session.graph_def.node) |
| elif graph_def is not None: |
| hints = _find_all_hints_in_nodes(graph_def.node) |
| for hint in _six.itervalues(hints): |
| _, output_nodes = hint.flattened_inputs_and_outputs() |
| hinted_outputs_nodes.extend(output_nodes) |
| return hinted_outputs_nodes |
| |
| |
| def is_ophint_converted(graph_def): |
| if graph_def is None: |
| raise ValueError("Must provide the graph_def.") |
| ophint_converted = False |
| for node in graph_def.node: |
| attr = node.attr |
| if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: |
| ophint_converted = True |
| break |
| return ophint_converted |
| |
| |
| @_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"]) |
| def convert_op_hints_to_stubs(session=None, |
| graph_def=None, |
| write_callback=lambda graph_def, comments: None): |
| """Converts a graphdef with LiteOp hints into stub operations. |
| |
| This is used to prepare for toco conversion of complex intrinsic usages. |
| Note: only one of session or graph_def should be used, not both. |
| |
| Args: |
| session: A TensorFlow session that contains the graph to convert. |
| graph_def: A graph def that we should convert. |
| write_callback: A function pointer that can be used to write intermediate |
| steps of graph transformation (optional). |
| |
| Returns: |
| A new graphdef with all ops contained in OpHints being replaced by |
| a single op call with the right parameters. |
| Raises: |
| ValueError: If both session and graph_def are provided. |
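| |
| Example (sketch, mirroring the module-level example; `sess` is a |
| tf.compat.v1.Session whose graph was built with OpHints): |
| stubbed_graph_def = tf.lite.experimental.convert_op_hints_to_stubs(sess) |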
| """ |
| |
| if session is not None and graph_def is not None: |
| raise ValueError("Provide only one of session and graph_def.") |
| |
| if session is not None: |
| return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback) |
| elif graph_def is not None: |
| return _convert_op_hints_to_stubs_helper(graph_def, write_callback) |
| else: |
| raise ValueError("Must specify session or graph_def as input.") |
| |
| |
| _allowed_symbols = [ |
| "OpHint", |
| "convert_op_hints_to_stubs", |
| "convert_op_hints_to_stubs_new", |
| "find_all_hinted_output_nodes", |
| "is_ophint_converted", |
| ] |
| remove_undocumented(__name__, _allowed_symbols) |