| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/cc/tools/freeze_saved_model.h" |
| |
| #include <iostream> |
| #include <queue> |
| |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/function.pb.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/protobuf/meta_graph.pb.h" |
| |
| namespace tensorflow { |
| |
| namespace { |
| |
| // Gets tensor names from tensor_info and inserts them into the set of tensor |
| // names. |
| void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info, |
| std::unordered_set<string>* tensor_names) { |
| if (tensor_info.has_coo_sparse()) { |
| // If the tensor is sparse we have to add all three tensors of the sparse |
| // representations. |
| const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse(); |
| tensor_names->insert(coo_sparse.values_tensor_name()); |
| tensor_names->insert(coo_sparse.indices_tensor_name()); |
| tensor_names->insert(coo_sparse.dense_shape_tensor_name()); |
| } else if (tensor_info.has_composite_tensor()) { |
| for (const auto& component : tensor_info.composite_tensor().components()) { |
| tensor_names->insert(component.name()); |
| } |
| } else { |
| tensor_names->insert(tensor_info.name()); |
| } |
| } |
| |
| // Gets the union of all inputs and outputs of all SignatureDefs in the bundle |
| void GetSignatureDefsInputsAndOutputs( |
| const SavedModelBundle& saved_model_bundle, |
| std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) { |
| for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) { |
| const SignatureDef& signature_def = sigdef_elem.second; |
| for (auto& input_elem : signature_def.inputs()) { |
| GetTensorNamesFromTensorInfo(input_elem.second, inputs); |
| } |
| for (auto& output_elem : signature_def.outputs()) { |
| GetTensorNamesFromTensorInfo(output_elem.second, outputs); |
| } |
| } |
| } |
| |
| // Gets a map from string node name to NodeDef. |
| void GetNodeNameToNodeDefMap( |
| GraphDef* graph_def, |
| std::unordered_map<string, NodeDef*>* name_to_node_map) { |
| for (size_t i = 0; i < graph_def->node_size(); i++) { |
| NodeDef* node = graph_def->mutable_node(i); |
| (*name_to_node_map)[node->name()] = node; |
| } |
| } |
| |
| // Strips off the tensor part of the tensor_name to get the node_name. |
| const string GetNodeNameFromTensorName(string tensor_name) { |
| if (tensor_name[0] == '^') { |
| tensor_name.erase(0, 1); |
| } |
| std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':'); |
| return tensor_name_parts[0]; |
| } |
| |
| // Gets the set of node names needed by `outputs` and the corresponding set of |
| // variable nodes to convert. |
| void GetReachableNodesAndVariables( |
| GraphDef* graph_def, const std::unordered_set<string>& outputs, |
| const std::unordered_map<string, NodeDef*>& name_to_node_map, |
| std::unordered_set<string>* reachable_node_names, |
| std::unordered_set<string>* variable_node_names) { |
| // TODO(suharshs): Add support for ResourceVariables. |
| static const std::unordered_set<string>* kVariableTypes = |
| new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"}); |
| |
| std::queue<string> nodes_to_visit; |
| for (const string& output_tensor_name : outputs) { |
| nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name)); |
| } |
| // We do a traversal backwards from the outputs specified in the MetaGraphDef. |
| while (!nodes_to_visit.empty()) { |
| const string node_name = nodes_to_visit.front(); |
| nodes_to_visit.pop(); |
| if (reachable_node_names->find(node_name) != reachable_node_names->end()) { |
| continue; |
| } |
| reachable_node_names->insert(node_name); |
| NodeDef* node = name_to_node_map.at(node_name); |
| if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { |
| variable_node_names->insert(node->name()); |
| } |
| for (const string& input_tensor_name : node->input()) { |
| nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name)); |
| } |
| } |
| } |
| |
| // Gets a map from variable name to variable value. |
| Status GetVariableNameToTensorMap( |
| Session* session, |
| const std::unordered_map<string, NodeDef*>& name_to_node_map, |
| std::unordered_set<string> variable_names_set, |
| std::unordered_map<string, Tensor>* variable_name_to_value_map) { |
| if (variable_names_set.empty()) { |
| return Status::OK(); |
| } |
| std::vector<string> variable_names; |
| variable_names.reserve(variable_names_set.size()); |
| std::vector<string> tensor_names; |
| tensor_names.reserve(variable_names_set.size()); |
| for (const string& node_name : variable_names_set) { |
| variable_names.push_back(node_name); |
| NodeDef* node_def = name_to_node_map.at(node_name); |
| if (node_def->op() == "VarHandleOp") { |
| // If this is a resource variable, we have to run the corresponding |
| // ReadVariableOp. |
| tensor_names.push_back(node_name + "/Read/ReadVariableOp:0"); |
| } else { |
| tensor_names.push_back(node_name + ":0"); |
| } |
| } |
| std::vector<Tensor> outputs; |
| TF_RETURN_IF_ERROR( |
| session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs)); |
| for (size_t i = 0; i < variable_names.size(); i++) { |
| (*variable_name_to_value_map)[variable_names[i]] = outputs[i]; |
| } |
| return Status::OK(); |
| } |
| |
| // Converts a Variable NodeDef into a Constant NodeDef. |
| void ConvertVariableToConstant(const NodeDef& variable_node, |
| const Tensor& variable_value, |
| NodeDef* const_node) { |
| const_node->set_name(variable_node.name()); |
| const_node->set_op("Const"); |
| (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype"); |
| variable_value.AsProtoTensorContent( |
| (*const_node->mutable_attr())["value"].mutable_tensor()); |
| } |
| |
| // Converts a ReadVariableOp NodeDef to an Identity NodeDef. |
| void ConvertReadVariableOpToIdentity(const NodeDef& node, |
| NodeDef* identity_node) { |
| identity_node->set_name(node.name()); |
| identity_node->set_op("Identity"); |
| (*identity_node->mutable_attr())["T"] = node.attr().at("dtype"); |
| identity_node->add_input(node.input(0)); |
| } |
| |
| // Freezes the subgraph of all nodes needed by `outputs`. |
| Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, |
| const std::unordered_set<string>& outputs, |
| GraphDef* frozen_graph_def) { |
| GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def(); |
| // Copy versions and library as-is from original graph. |
| *frozen_graph_def->mutable_versions() = graph_def.versions(); |
| *frozen_graph_def->mutable_library() = graph_def.library(); |
| // If the graph is empty there is nothing left to do. |
| if (graph_def.node_size() == 0) { |
| return Status::OK(); |
| } |
| // name_to_node_map is needed to get the inputs from the NodeDef corresponding |
| // the a string node name. These inputs are used when doing our backwards |
| // traversal. |
| std::unordered_map<string, NodeDef*> name_to_node_map; |
| GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map); |
| std::unordered_set<string> reachable_node_names; |
| std::unordered_set<string> variable_node_names; |
| GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map, |
| &reachable_node_names, &variable_node_names); |
| std::unordered_map<string, Tensor> variable_to_value_map; |
| TF_RETURN_IF_ERROR(GetVariableNameToTensorMap( |
| saved_model_bundle.session.get(), name_to_node_map, variable_node_names, |
| &variable_to_value_map)); |
| // We copy the nodes in the same order they were in the original graph_def. |
| for (const NodeDef& node : graph_def.node()) { |
| if (reachable_node_names.find(node.name()) == reachable_node_names.end()) { |
| continue; |
| } |
| if (variable_node_names.find(node.name()) != variable_node_names.end()) { |
| ConvertVariableToConstant(node, variable_to_value_map[node.name()], |
| frozen_graph_def->add_node()); |
| } else if (node.op() == "ReadVariableOp" && |
| variable_node_names.find(node.input(0)) != |
| variable_node_names.end()) { |
| // If the node is a ReadVariableOp, its input VarHandleOp will be |
| // converted to a Constant, so we will need to convert it to an Identity. |
| ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node()); |
| } else { |
| // If the node isn't a variable, just copy the node as-is. |
| *frozen_graph_def->add_node() = node; |
| } |
| } |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, |
| GraphDef* frozen_graph_def, |
| std::unordered_set<string>* inputs, |
| std::unordered_set<string>* outputs) { |
| GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs); |
| TF_RETURN_IF_ERROR( |
| FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def)); |
| return Status::OK(); |
| } |
| |
| } // namespace tensorflow |