Fix style in `op_hint.py` to match formatting from Copybara.
No functional changes
PiperOrigin-RevId: 311566454
Change-Id: Ic4f002df42168bdb8841b80a93ebf22a8e7fa4bd
diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py
index 159fcaa..9d62c1b 100644
--- a/tensorflow/lite/python/op_hint.py
+++ b/tensorflow/lite/python/op_hint.py
@@ -435,6 +435,7 @@
Args:
*args: List of inputs to be converted (should be Tf.Tensor).
**kwargs: This allows 'names' which should be a list of names.
+
Returns:
Wrapped inputs (identity standins that have additional metadata). These
are also are also tf.Tensor's.
@@ -453,6 +454,7 @@
Args:
*args: List of outputs to be converted (should be tf.Tensor).
**kwargs: See
+
Returns:
Wrapped outputs (identity standins that have additional metadata). These
are also tf.Tensor's.
@@ -574,8 +576,8 @@
elif self.aggregation == OpHint.AGGREGATE_STACK:
pass
else:
- raise ValueError(
- "Invalid aggregation type %r specified" % self.aggregation)
+ raise ValueError("Invalid aggregation type %r specified" %
+ self.aggregation)
return self.flattened
def flatten(self):
@@ -646,8 +648,8 @@
stack_node.attr["num"].i = len(flattened)
output_type = flattened[0].attr["T"].type
stack_node.attr["T"].type = output_type
- stack_node.input.append(_tensorflow_output_name(
- fused_op_name, output_index))
+ stack_node.input.append(
+ _tensorflow_output_name(fused_op_name, output_index))
out_graphdef.node.extend([stack_node])
for idx, discrete in enumerate(flattened):
@@ -675,11 +677,10 @@
inputs: inputs to the op (hash from index # to argument)
outputs: outputs to the op (hash from index # to argument)
function_name: the tflite custom op name to use
- uuid: a unique call id for this particular call (i.e.
- multiple function calls would have the same function_name but different
- uuids.
- params: A param name to key value for op constant data. I.e. for
- axis on a reduction, strides on a convolution, etc.
+ uuid: a unique call id for this particular call (i.e. multiple function
+ calls would have the same function_name but different uuids.
+ params: A param name to key value for op constant data. I.e. for axis on a
+ reduction, strides on a convolution, etc.
level: Level of the OpHint.
children_inputs_mappings: If the Ophint has children, children inputs
mappings indicate how their inputs & outputs are mapped.
@@ -700,6 +701,7 @@
Returns:
Tuple of (inputs, outputs). where input and output i a list of names.
"""
+
def _flatten(input_or_output_dict):
flattened_items = []
for item in input_or_output_dict.values():
@@ -709,6 +711,7 @@
return _flatten(self.inputs), _flatten(self.outputs)
def __str__(self):
+
def format_args(items):
s = ""
for idx, item in items.iteritems():
@@ -739,8 +742,8 @@
for node in nodes:
attr = node.attr
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
- if (OpHint.FUNCTION_UUID_ATTR not in attr
- or not attr[OpHint.FUNCTION_UUID_ATTR].s):
+ if (OpHint.FUNCTION_UUID_ATTR not in attr or
+ not attr[OpHint.FUNCTION_UUID_ATTR].s):
continue
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
@@ -751,9 +754,11 @@
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
# Get sorting and aggregation information
- sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
- if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
- if sort == -1: sort = None
+ sort = (
+ attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
+ if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
+ if sort == -1:
+ sort = None
aggregation = None
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
@@ -887,6 +892,7 @@
Args:
full_tensor_name: A tensor name that is annotated with a device placement
(this is what tensor flow introspection gives).
+
Returns:
A name without any device assignment.
"""
@@ -919,10 +925,10 @@
while next_to_visit:
current_node = next_to_visit.pop()
visited.add(current_node)
- if (current_node in reachable_by_input
- and current_node not in input_nodes_set):
- raise TypeError(
- "Node %s uses input %s not in input_nodes." % (n, current_node))
+ if (current_node in reachable_by_input and
+ current_node not in input_nodes_set):
+ raise TypeError("Node %s uses input %s not in input_nodes." %
+ (n, current_node))
if current_node not in input_nodes_set:
next_to_visit += [
input_node for input_node in name_to_input_name[current_node]
@@ -1066,6 +1072,7 @@
Args:
in_graph_def: Graph def to use as input.
+
Returns:
Simplified tuple (graph_def, changed_something) where changed_something
is true if anything was done.
@@ -1101,15 +1108,15 @@
node = name_to_node[current_node_name]
is_op_hint_stack = node.name.startswith("OpHintStack")
is_op_hint_unstack = node.name.startswith("OpHintUnstack")
- if (node.op == "Identity" or is_op_hint_stack
- or (do_generic_pack_unpack and node.op == "Pack")):
+ if (node.op == "Identity" or is_op_hint_stack or
+ (do_generic_pack_unpack and node.op == "Pack")):
is_hint_created_stack |= is_op_hint_stack
next_to_visit += [
input_node for input_node in name_to_input_name[current_node_name]
if input_node not in visited
]
- elif (is_op_hint_unstack
- or (do_generic_pack_unpack and node.op == "Unpack")):
+ elif (is_op_hint_unstack or
+ (do_generic_pack_unpack and node.op == "Unpack")):
unpack_nodes.add(node.name)
is_hint_created_stack &= is_op_hint_unstack
else:
@@ -1124,7 +1131,8 @@
# Unstacked form
no_external_dependency = True
for other_n in in_graph_def.node:
- if other_n.name in visited: continue
+ if other_n.name in visited:
+ continue
for input_tensor in name_to_input_name[other_n.name]:
input_op = _tensor_name_base(input_tensor)
if input_op in visited and input_op != pack_node:
@@ -1141,9 +1149,9 @@
if node_name not in visited:
new_node = _copy.deepcopy(other_n)
new_node.input[:] = [
- (end_input if stripped == pack_node else
- non_stripped) for stripped, non_stripped in zip(
- name_to_input_name[node_name], new_node.input[:])
+ (end_input if stripped == pack_node else non_stripped)
+ for stripped, non_stripped in zip(name_to_input_name[node_name],
+ new_node.input[:])
]
out.node.extend([new_node])
return out, True
@@ -1177,6 +1185,7 @@
graph_def: A graph def that we should convert.
write_callback: A function pointer that can be used to write intermediate
steps of graph transformation (optional).
+
Returns:
A new stubbed graph_def.
"""
@@ -1306,6 +1315,7 @@
graph_def: A graph def that we should convert.
write_callback: A function pointer that can be used to write intermediate
steps of graph transformation (optional).
+
Returns:
A new graphdef with all ops contained in OpHints being replaced by
a single op call with the right parameters.