# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to convert variables to constants in TensorFlow 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import object_identity


_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)


def disable_lower_using_switch_merge(graph_def):
  """Set '_lower_using_switch_merge' attributes to False.

  Sets the attribute to False in the NodeDefs in the main graph and in the
  NodeDefs in each function's graph.

  Args:
    graph_def: GraphDef proto.

  Returns:
    GraphDef
  """
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.CopyFrom(graph_def)

  def disable_control_flow_lowering(node):
    if node.op in _CONTROL_FLOW_OPS:
      node.attr["_lower_using_switch_merge"].b = False

  for node in output_graph_def.node:
    disable_control_flow_lowering(node)

  if output_graph_def.library:
    for func in output_graph_def.library.function:
      for node in func.node_def:
        disable_control_flow_lowering(node)
  return output_graph_def


def _run_inline_graph_optimization(func, lower_control_flow):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)
  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name in the variables collections, since it is not
  # needed once the model has been saved.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def[name].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function
  # inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _get_tensor_name(name):
  """Returns the tensor name without the output index suffix.

  For example, "conv1/kernel:0" is mapped to "conv1/kernel".

  Args:
    name: str

  Returns:
    str
  """
  return name.split(":")[0]


def _get_new_function_name(name):
  """Returns the function name with '_frozen' appended.

  Args:
    name: str

  Returns:
    str
  """
  return name + "_frozen"


def _get_node_defs_list(graph_def):
  """Returns a list of NodeDefs in the GraphDef.

  This list consists of all NodeDefs in the main graph as well as all control
  flow NodeDefs in the functions.

  The remaining NodeDefs in the functions are not included because the op names
  are not unique and the variables are handled differently than in the main
  graph. The control flow ops are extracted because their attributes need to be
  updated in the same way as the control flow ops in the main graph.

  Args:
    graph_def: GraphDef proto.

  Returns:
    [NodeDef]
  """
  node_defs = list(graph_def.node)

  if graph_def.library:
    for func in graph_def.library.function:
      node_defs.extend(
          [node for node in func.node_def if node.op in _CONTROL_FLOW_OPS])
  return node_defs


def _get_tensor_data(func):
  """Gets the tensor data for all Placeholders in the model.

  Returns a dictionary that maps the tensor name to a dictionary containing:
    data: numpy data
    index: int index in func.graph.captures
    is_variable: bool indicating whether the tensor is a variable or not

  Args:
    func: ConcreteFunction.

  Returns:
    Dict
  """
  tensor_data = {}
  map_index_to_variable = {}
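  # Match each variable in the function's graph to the index of its resource
  # handle in `func.captured_inputs`, so that the variable's value can be
  # looked up by capture index in the loop below.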
  for var in func.graph.variables:
    for idx, captured_input in enumerate(func.captured_inputs):
      if var.handle is captured_input:  # pylint: disable=protected-access
        map_index_to_variable[idx] = var
        break

  # Iterates through all captures which are represented as Placeholders.
  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
    tensor_name = _get_tensor_name(name_tensor.name)
    is_variable = idx in map_index_to_variable
    if is_variable:
      data = map_index_to_variable[idx].numpy()
    else:
      data = val_tensor.numpy()
    tensor_data[tensor_name] = {
        "data": data,
        "index": idx,
        "is_variable": is_variable,
    }
  return tensor_data


def _get_control_flow_function_data(node_defs, tensor_data):
  """Gets the types and shapes for the parameters to the function.

  Creates a map from function name to a list of types and a list of shapes that
  correspond with the function arguments. The data is primarily determined from
  the corresponding "If" or "While" op. If the argument is a resource variable,
  then the type is determined from the type of the data contained within the
  Tensor. The shape data is only determined in the case of the "While" op.

  `is_also_output_type` is used to identify the "While" bodies that require the
  output types to be updated at the same time the input types are updated.

  Args:
    node_defs: List of NodeDefs.
    tensor_data: {str name : Tensor}.

  Returns:
    {str function name : {"types" : [int representing DataType],
                          "shapes" : [[int] representing TensorShape],
                          "is_also_output_type" : bool}}
| """ |
| func_data = {} |
| |
| def get_resource_type(node_name): |
| numpy_type = tensor_data[node_name]["data"].dtype |
| return dtypes.as_dtype(numpy_type).as_datatype_enum |
| |
| def get_resource_shape(node_name): |
| return tensor_shape_pb2.TensorShapeProto(dim=[ |
| tensor_shape_pb2.TensorShapeProto.Dim(size=dim) |
| for dim in tensor_data[node_name]["data"].shape |
| ]) |
| |
| def add_value(func_name, arg_types, output_shapes, is_also_output_type): |
| func_data[func_name] = { |
| "types": arg_types, |
| "shapes": output_shapes, |
| "is_also_output_type": is_also_output_type |
| } |
| |
| for node in node_defs: |
| if node.op in _CONDITIONAL_OPS: |
| arg_types = [dtype for dtype in node.attr["Tin"].list.type] |
| |
| for idx in range(len(arg_types)): |
| if arg_types[idx] == dtypes.resource: |
| # Skip first index which represents the condition. |
| arg_types[idx] = get_resource_type(node.input[idx + 1]) |
| |
| add_value(node.attr["then_branch"].func.name, arg_types, None, False) |
| add_value(node.attr["else_branch"].func.name, arg_types, None, False) |
| elif node.op in _LOOP_OPS: |
| arg_types = [dtype for dtype in node.attr["T"].list.type] |
| output_shapes = [shape for shape in node.attr["output_shapes"].list.shape] |
| |
| for idx in range(len(arg_types)): |
| if arg_types[idx] == dtypes.resource: |
| input_name = node.input[idx] |
| arg_types[idx] = get_resource_type(input_name) |
| output_shapes[idx] = get_resource_shape(input_name) |
| |
| add_value(node.attr["body"].func.name, arg_types, output_shapes, True) |
| add_value(node.attr["cond"].func.name, arg_types, output_shapes, False) |
| return func_data |
| |
| |
def _populate_const_op(output_node, node_name, dtype, data, data_shape):
  """Creates a Const op.

  Args:
    output_node: TensorFlow NodeDef.
    node_name: str node name.
    dtype: AttrValue with a populated .type field.
    data: numpy data value.
    data_shape: Tuple of integers containing data shape.
  """
  output_node.op = "Const"
  output_node.name = node_name
  output_node.attr["dtype"].CopyFrom(dtype)
  tensor = tensor_util.make_tensor_proto(
      data, dtype=dtype.type, shape=data_shape)
  output_node.attr["value"].tensor.CopyFrom(tensor)


def _populate_identity_op(output_node, input_node):
  """Creates an Identity op from a ReadVariable op.

  Args:
    output_node: TensorFlow NodeDef.
    input_node: TensorFlow NodeDef.
  """
  output_node.op = "Identity"
  output_node.name = input_node.name
  output_node.input.append(input_node.input[0])
  output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
  if "_class" in input_node.attr:
    output_node.attr["_class"].CopyFrom(input_node.attr["_class"])


def _populate_if_op(output_node, input_node, function_data):
  """Updates the type attributes and function names of If or StatelessIf.

  Args:
    output_node: TensorFlow NodeDef.
    input_node: TensorFlow NodeDef.
    function_data: Map of function names to the list of types and shapes that
      correspond with the function arguments.
  """
  output_node.CopyFrom(input_node)
  then_func = input_node.attr["then_branch"].func.name
  output_node.attr["then_branch"].func.name = _get_new_function_name(then_func)
  output_node.attr["else_branch"].func.name = _get_new_function_name(
      input_node.attr["else_branch"].func.name)
  output_node.attr["Tin"].list.CopyFrom(
      attr_value_pb2.AttrValue.ListValue(
          type=function_data[then_func]["types"]))


def _populate_while_op(output_node, input_node, function_data):
  """Updates the type attributes and function names of While or StatelessWhile.

  Args:
    output_node: TensorFlow NodeDef.
    input_node: TensorFlow NodeDef.
    function_data: Map of function names to the list of types and shapes that
      correspond with the function arguments.
  """
  output_node.CopyFrom(input_node)
  cond_func = input_node.attr["cond"].func.name
  output_node.attr["cond"].func.name = _get_new_function_name(cond_func)
  output_node.attr["body"].func.name = _get_new_function_name(
      input_node.attr["body"].func.name)
  output_node.attr["T"].list.CopyFrom(
      attr_value_pb2.AttrValue.ListValue(
          type=function_data[cond_func]["types"]))
  output_node.attr["output_shapes"].list.CopyFrom(
      attr_value_pb2.AttrValue.ListValue(
          shape=function_data[cond_func]["shapes"]))


def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction.
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  # Create a ConcreteFunction from the new GraphDef.
  input_tensors = func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  not_converted_inputs = object_identity.ObjectIdentitySet(
      func.inputs).difference(converted_inputs)
  not_converted_inputs_map = {
      tensor.name: tensor for tensor in not_converted_inputs
  }

  new_input_names = [tensor.name for tensor in not_converted_inputs]
  new_output_names = [tensor.name for tensor in func.outputs]
  new_func = wrap_function.function_from_graph_def(output_graph_def,
                                                   new_input_names,
                                                   new_output_names)

  # Manually propagate shape for input tensors where the shape is not correctly
  # propagated. Scalar shapes are lost when wrapping the function.
  for input_tensor in new_func.inputs:
    input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape)
  return new_func


def convert_variables_to_constants_v2(func, lower_control_flow=True):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops
  holding the same values. This makes it possible to describe the network fully
  with a single GraphDef file, and allows the removal of a lot of ops related
  to loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation supports graphs whose resource tensors are
  consumed only by ReadVariableOp, VariableV2, ResourceGather (with
  batch_dims == 0), and the control flow ops If, StatelessIf, While, and
  StatelessWhile; other resource-related ops are not supported.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    ConcreteFunction containing a simplified version of the original.
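
  Example (a minimal sketch; the Keras model and input shape below are purely
  illustrative):

  ```python
  import tensorflow as tf

  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(units=1, input_shape=[4])])
  concrete_func = tf.function(model).get_concrete_function(
      tf.TensorSpec(shape=[None, 4], dtype=tf.float32))

  # The returned function holds the variable values as Const ops, so its graph
  # can be serialized as a single GraphDef.
  frozen_func = convert_variables_to_constants_v2(concrete_func)
  graph_def = frozen_func.graph.as_graph_def()
  ```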
| """ |
| # Inline the graph in order to remove functions when possible. |
| graph_def = _run_inline_graph_optimization(func, lower_control_flow) |
| |
| # Gets list of all node defs include those in the library. |
| node_defs = _get_node_defs_list(graph_def) |
| |
| # Get mapping from node name to node. |
| name_to_node = {_get_tensor_name(node.name): node for node in node_defs} |
| |
| # Get mapping from node name to variable value. |
| tensor_data = _get_tensor_data(func) |
| |
| # Get mapping from function name to argument types. |
| function_data = _get_control_flow_function_data(node_defs, tensor_data) |
| |
| # Get variable data for all nodes in `node_defs`. |
| reference_variables = {} |
| resource_identities = {} |
| placeholders = {} |
| converted_input_indices = set() |
| |
| def _save_placeholder(node_name, dtype): |
| placeholders[node_name] = { |
| "dtype": dtype, |
| "data": tensor_data[node_name]["data"], |
| } |
| converted_input_indices.add(tensor_data[node_name]["index"]) |
| |
| for node in node_defs: |
| if node.op in _CONDITIONAL_OPS: |
| # Get dtype and data for resource Placeholders. |
| then_func = node.attr["then_branch"].func.name |
| arg_types = function_data[then_func]["types"] |
| for idx, input_tensor in enumerate(node.input[1:]): |
| input_name = _get_tensor_name(input_tensor) |
| if input_name in tensor_data: |
| dtype = attr_value_pb2.AttrValue(type=arg_types[idx]) |
| _save_placeholder(_get_tensor_name(input_tensor), dtype) |
| elif node.op in _LOOP_OPS: |
| # Get dtype and data for resource Placeholders. |
| cond_func = node.attr["cond"].func.name |
| arg_types = function_data[cond_func]["types"] |
| for idx, input_tensor in enumerate(node.input): |
| input_name = _get_tensor_name(input_tensor) |
| if input_name in tensor_data: |
| dtype = attr_value_pb2.AttrValue(type=arg_types[idx]) |
| _save_placeholder(_get_tensor_name(input_tensor), dtype) |
| elif (node.op == "Identity" and node.attr["T"].type == dtypes.resource and |
| name_to_node[_get_tensor_name(node.input[0])].op in _LOOP_OPS): |
| # Store the dtype for Identity resource ops that are outputs of While ops. |
| while_node = name_to_node[_get_tensor_name(node.input[0])] |
| body_func = while_node.attr["body"].func.name |
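      # The Identity op's input refers to a specific While op output, e.g.
      # "while:2"; recover that output index (default 0) so the matching
      # argument type can be looked up in the body function's data.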
      input_data = node.input[0].split(":")
      idx = 0 if len(input_data) == 1 else int(input_data[1])

      dtype = attr_value_pb2.AttrValue(
          type=function_data[body_func]["types"][idx])
      resource_identities[node.name] = dtype
    elif node.op == "VariableV2":
      # Get data for VariableV2 ops (reference variables) that cannot be
      # lifted.
      with func.graph.as_default():
        identity_node = array_ops.identity(
            func.graph.as_graph_element(node.name + ":0"))
      reference_variables[node.name] = (
          func.prune([], [identity_node.name])()[0])
    elif node.name in tensor_data and not tensor_data[node.name]["is_variable"]:
      # Get dtype and data for non-variable Placeholders (e.g. values for 1.X
      # Const ops that are loaded as Placeholders in 2.0).
      _save_placeholder(node.name, node.attr["dtype"])
    elif node.op in ["ReadVariableOp", "ResourceGather"]:
      # Get dtype and data for Placeholder ops associated with ReadVariableOp
      # and ResourceGather ops. There can be an Identity in between the
      # resource op and Placeholder. Store the dtype for the Identity ops.
      input_name = _get_tensor_name(node.input[0])
      while name_to_node[input_name].op == "Identity":
        resource_identities[input_name] = node.attr["dtype"]
        input_name = _get_tensor_name(name_to_node[input_name].input[0])
      if name_to_node[input_name].op != "Placeholder":
        raise ValueError("Cannot find the Placeholder op that is an input "
                         "to the ReadVariableOp.")
      _save_placeholder(input_name, node.attr["dtype"])

  # Reconstruct the graph with constants in place of variables.
  output_graph_def = graph_pb2.GraphDef()

  for input_node in graph_def.node:
    output_node = output_graph_def.node.add()
    # Convert VariableV2 ops to Const ops.
    if input_node.name in reference_variables:
      data = reference_variables[input_node.name]
      dtype = attr_value_pb2.AttrValue(type=data.dtype.as_datatype_enum)
      _populate_const_op(output_node, input_node.name, dtype, data.numpy(),
                         data.shape)
    # Convert Placeholder ops to Const ops.
    elif input_node.name in placeholders:
      data = placeholders[input_node.name]["data"]
      dtype = placeholders[input_node.name]["dtype"]
      _populate_const_op(output_node, input_node.name, dtype, data, data.shape)
    # Update the dtype for Identity ops that are inputs to ReadVariableOps.
    elif input_node.name in resource_identities:
      output_node.CopyFrom(input_node)
      output_node.attr["T"].CopyFrom(resource_identities[input_node.name])
    # Convert ReadVariableOps to Identity ops.
    elif input_node.op == "ReadVariableOp":
      _populate_identity_op(output_node, input_node)
    # Convert ResourceGather to Gather ops with a Const axis feeding into it.
    elif input_node.op == "ResourceGather":
      if input_node.attr["batch_dims"].i != 0:
        raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
      output_axis_node = output_graph_def.node.add()
      axis_node_name = input_node.name + "/axis"
      axis_dtype = input_node.attr["Tindices"]
      axis_data = np.array(input_node.attr["batch_dims"].i)
      _populate_const_op(output_axis_node, axis_node_name, axis_dtype,
                         axis_data, axis_data.shape)

      output_node.op = "GatherV2"
      output_node.name = input_node.name
      output_node.input.extend(
          [input_node.input[0], input_node.input[1], axis_node_name])
      output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
      output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
      output_node.attr["Taxis"].CopyFrom(axis_dtype)
      if "_class" in input_node.attr:
        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
    # Update the function names and argument types for the conditional ops.
    elif input_node.op in _CONDITIONAL_OPS:
      _populate_if_op(output_node, input_node, function_data)
    elif input_node.op in _LOOP_OPS:
      _populate_while_op(output_node, input_node, function_data)
    else:
      output_node.CopyFrom(input_node)

  # Add functions to reconstructed graph.
  if graph_def.library:
    library = output_graph_def.library

    for input_library_func in graph_def.library.function:
      orig_func_name = input_library_func.signature.name
      new_func_name = _get_new_function_name(orig_func_name)

      # Do not copy any functions that aren't being used in the graph. Any
      # functions that are not used by control flow should have been inlined.
      if orig_func_name not in function_data:
        continue

      output_library_func = library.function.add()
      for key, value in input_library_func.ret.items():
        output_library_func.ret[key] = value
      for key, value in input_library_func.control_ret.items():
        output_library_func.control_ret[key] = value

      # Update the input types in the function signature. Update the output
      # types for functions that are while loop bodies.
      output_library_func.signature.CopyFrom(input_library_func.signature)
      output_library_func.signature.name = new_func_name
      for dtype, arg in zip(function_data[orig_func_name]["types"],
                            output_library_func.signature.input_arg):
        arg.type = dtype
      if function_data[orig_func_name]["is_also_output_type"]:
        for dtype, arg in zip(function_data[orig_func_name]["types"],
                              output_library_func.signature.output_arg):
          arg.type = dtype

      # Update the NodeDefs.
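      # Map the name of each ReadVariableOp in this function to its resource
      # input, so that consumers of its ":value" output can be rewired to the
      # ":output" of the Identity op it is converted into below.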
      func_variables = {
          node.name: node.input[0]
          for node in input_library_func.node_def
          if node.op == "ReadVariableOp"
      }

      for input_node in input_library_func.node_def:
        output_node = output_library_func.node_def.add()
        # Convert ReadVariableOps to Identity ops.
        if input_node.op == "ReadVariableOp":
          _populate_identity_op(output_node, input_node)
        # Update the function names and argument types for the conditional ops.
        elif input_node.op in _CONDITIONAL_OPS:
          _populate_if_op(output_node, input_node, function_data)
        elif input_node.op in _LOOP_OPS:
          _populate_while_op(output_node, input_node, function_data)
        else:
          output_node.CopyFrom(input_node)
        # Convert :value to :output for ops that use the ReadVariableOp.
        for idx, full_name in enumerate(input_node.input):
          input_name = _get_tensor_name(full_name)
          if input_name in func_variables:
            full_name_parts = full_name.split(":")
            full_name_parts[1] = "output"
            input_name = ":".join(full_name_parts)
          output_node.input[idx] = input_name

  output_graph_def.versions.CopyFrom(graph_def.versions)
  return _construct_concrete_function(func, output_graph_def,
                                      converted_input_indices)