| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/grappler/grappler_item_builder.h" |
| |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <vector> |
| |
| #include "tensorflow/core/common_runtime/device.h" |
| #include "tensorflow/core/common_runtime/device_factory.h" |
| #include "tensorflow/core/common_runtime/device_mgr.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/graph_optimizer.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/function.pb.h" |
| #include "tensorflow/core/framework/graph_def_util.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/framework/variable.pb.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/grappler/inputs/utils.h" |
| #include "tensorflow/core/grappler/op_types.h" |
| #include "tensorflow/core/grappler/optimizers/model_pruner.h" |
| #include "tensorflow/core/grappler/utils.h" |
| #include "tensorflow/core/lib/gtl/map_util.h" |
| #include "tensorflow/core/lib/io/path.h" |
| #include "tensorflow/core/platform/protobuf_internal.h" |
| #include "tensorflow/core/protobuf/meta_graph.pb.h" |
| #include "tensorflow/core/protobuf/saver.pb.h" |
| #include "tensorflow/core/public/session_options.h" |
| |
| namespace tensorflow { |
| namespace grappler { |
| |
| namespace { |
| |
| void InitializeTensor(DataType type, Tensor* tensor) { |
| const int period = 7; |
| if (type == DT_FLOAT) { |
| auto flat = tensor->flat<float>(); |
| // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ... |
| for (int i = 0; i < flat.size(); i++) { |
| flat(i) = static_cast<float>(i % period) / 10.0f; |
| } |
| } else if (type == DT_INT64) { |
| auto flat = tensor->flat<int64>(); |
| // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ... |
| for (int i = 0; i < flat.size(); i++) { |
| flat(i) = i % period; |
| } |
| } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) { |
| // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to |
| // is_simple_type<> in tensorflow/core/framework/type_traits.h, and |
| // Allocator will run non-trivial constructor/destructor for a Tensor with |
| // one of these types, so we should not memset its buffer. |
| memset(const_cast<char*>(tensor->tensor_data().data()), 0, |
| tensor->tensor_data().size()); |
| } |
| } |
| |
| // Applies the same graph pruning logic to the graph as Session.Run in TF. |
| // If the returned status is not OK, item state may be inconsistent. |
| Status PruneGraph(GrapplerItem* item) { |
| ModelPruner pruner; |
| GraphDef pruned_graph; |
| Cluster* cluster = nullptr; // ModelPruner doesn't check cluster. |
| TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph)); |
| item->graph = std::move(pruned_graph); |
| return Status::OK(); |
| } |
| |
| // Replace any unknown dimensions in a shape with |
| // cfg.placeholder_unknown_output_shape_dim if it is no less than 0. |
| Status ReplaceUnknownShapeDim(const ItemConfig& cfg, |
| const TensorShapeProto& shape_pb_in, |
| TensorShapeProto* shape_pb_out, |
| TensorShape* shape_out) { |
| std::vector<int32> dims; |
| for (const auto& dim_proto : shape_pb_in.dim()) { |
| if (cfg.placeholder_unknown_output_shape_dim >= 0 && |
| dim_proto.size() == -1) { |
| dims.push_back(cfg.placeholder_unknown_output_shape_dim); |
| shape_pb_out->add_dim()->set_size( |
| cfg.placeholder_unknown_output_shape_dim); |
| } else { |
| dims.push_back(std::max<int32>(1, dim_proto.size())); |
| shape_pb_out->add_dim()->set_size(dim_proto.size()); |
| } |
| } |
| return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out); |
| } |
| |
| // Replace unknown dimensions in Placeholder shape if |
| // cfg.placeholder_unknown_output_shape_dim is set or |
| // the Placeholder node has _output_shapes. |
| // Otherwise keep it intact to keep compatible with shape annotation |
| // (b/134092018). |
// Gives the Placeholder `node` a concrete shape and registers a fake feed
// tensor for it on `new_item`.
//
// Unknown dimensions are replaced when
// cfg.placeholder_unknown_output_shape_dim is set or when the node carries a
// usable _output_shapes attribute; otherwise the shape attribute is kept
// intact to stay compatible with shape annotation (b/134092018).
//
// Returns an Internal error (the caller skips the input) when the node lacks
// the "dtype"/"shape" attributes or the shape cannot be materialized.
Status UpdatePlaceholderShape(
    const ItemConfig& cfg,
    const std::unordered_set<string>& signature_feed_nodes,
    GrapplerItem* new_item, NodeDef* node) {
  if (node->attr().count("dtype") == 0) {
    return errors::Internal("Unknown type for placeholder ", node->name(),
                            ", skipping this input");
  }
  DataType type = node->attr().at("dtype").type();

  // TODO(andiryxu): Consider cfg.placeholder_unknown_output_shape_dim >= 0 and
  // _output_shapes is present case.
  if (node->attr().count("shape") == 0) {
    return errors::Internal("Unknown shape for placeholder ", node->name(),
                            ", skipping this input");
  }

  // Replace all unknown dimensions in the placeholder's tensorshape proto
  // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
  // from it. We do this because in newer protos, the input placeholder
  // shape is not empty if the shape is partially defined.
  TensorShape shape;
  TensorShapeProto shape_proto;
  Status make_shape_status = ReplaceUnknownShapeDim(
      cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
  if (!make_shape_status.ok()) {
    return errors::Internal("Invalid shape for placeholder ", node->name(),
                            ": ", make_shape_status, ", skipping this input");
  }

  // Some placeholder nodes have a mis-match between the node
  // attribute "shape" and a different node attribute "_output_shapes".
  // Specifically, a shape with shape.dims() == 0 could indicate either
  // a scalar or an unknown shape. In those cases, we check _output_shapes
  // for additional information.
  // This case is observed in the bnmt graphs. Have not observed any
  // cases where there was more than 1 _output_shapes, so limit it
  // to cases where there is only 1 _output_shapes.
  // We only do this if cfg.placeholder_unknown_output_shape_dim has
  // been set to avoid crashing non-BNMT graphs.
  // TODO(andiryxu): Investigate if this is a bug in BNMT graph.
  if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
      (node->attr().count("_output_shapes") == 1)) {
    const auto& output_shapes =
        node->attr().at("_output_shapes").list().shape(0);

    if (output_shapes.dim_size() != 0) {
      // _output_shapes has real rank information: rebuild both the concrete
      // shape and the proto from it, substituting unknown (-1) dims.
      shape.Clear();
      shape_proto.clear_dim();

      for (const auto& dim : output_shapes.dim()) {
        auto size = dim.size();
        if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
        shape.AddDim(size);
        shape_proto.add_dim()->set_size(size);
      }
    }
  }

  // Build a deterministic fake tensor of the resolved shape to use as feed.
  Tensor fake_input(type, shape);
  InitializeTensor(type, &fake_input);

  if (cfg.feed_nodes.empty()) {
    // No specific feed nodes were given. Assume all placeholders are fed,
    // except those already fed from the signature defs.
    if (signature_feed_nodes.count(node->name()) == 0) {
      new_item->feed.emplace_back(node->name(), fake_input);
    }
  } else if (cfg.feed_nodes.count(node->name()) > 0) {
    // If specific feed nodes were given, only update their tensors. The
    // entry was already added from cfg.feed_nodes with an empty Tensor.
    auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
                      [&node](std::pair<string, Tensor>& f) {
                        return f.first == node->name();
                      });
    DCHECK(it != new_item->feed.end());
    it->second = fake_input;
  }

  // Set the shape of the node in the graph. This is needed for statically
  // inferring shapes and is a no-op when dynamically inferring shapes as
  // the Placeholder shape will match the shape passed from new_item->feed.
  // Only replace node shape with known shape. For unknown shape keep it intact
  // (b/134092018).
  if (!shape_proto.dim().empty())
    *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;

  return Status::OK();
}
| |
| } // namespace |
| |
// Runs the TF runtime's GraphOptimizer (L0/L1 optimizations and optional
// function inlining) over `graph_def_arg` on a CPU-only device set and
// writes the result to `output_graph_def`. Controlled by
// cfg.apply_optimizations, cfg.inline_functions and
// cfg.erase_noinline_attributes; when none of these require work, the input
// is left untouched and OK is returned. Note: in- and out-params may alias,
// since a local copy of the input is optimized.
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def,
                             const ItemConfig& cfg) {
  // This is a temporary change that optimizes the graph in context of a single
  // gpu machine. Down the line, we may want to make grappler_item_builder aware
  // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated
  // in order to get the correct session options and environment, and performing
  // the correct optimizations.

  if (!cfg.apply_optimizations && !cfg.erase_noinline_attributes) {
    return Status::OK();
  }

  // Create a session option for a single GPU device.
  SessionOptions options;

  // Make a local copy of graph def, because we need to change some things.
  GraphDef graph_def(graph_def_arg);

  if (cfg.erase_noinline_attributes) {
    // TF optimizer doesn't inline functions with "_noinline" attribute,
    // so let's go over the function library and erase it.
    for (auto& func : *graph_def.mutable_library()->mutable_function()) {
      func.mutable_attr()->erase("_noinline");
    }
  }

  // Instantiate all variables for function library runtime creation.
  std::vector<std::unique_ptr<Device>> devices;
  // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
  // with dummy session config, which will conflict with user defined options
  // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
  // only devices.
  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  Device* cpu_device = devices[0].get();
  // The device manager takes ownership of the devices; it must outlive
  // `pflr`/`flr` below, which hold raw pointers into it.
  auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             graph_def.library());
  Env* env = Env::Default();

  // Optimizer options: L1 and inlining. L1 is default.
  OptimizerOptions* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  if (cfg.apply_optimizations) {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
  }
  optimizer_opts->set_do_function_inlining(cfg.inline_functions);

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
                                        graph_def.versions().producer(),
                                        &function_library, *optimizer_opts));
  // `flr` is owned by `pflr`; do not delete.
  FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

  // Create the GraphOptimizer to optimize the graph def.
  GraphConstructorOptions graph_ctor_opts;
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));

  // Note: `graph_def` is consumed here and must not be used afterwards.
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      graph_ctor_opts, std::move(graph_def), graphptr.get()));

  // Optimize the graph.
  ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
  graphptr->ToGraphDef(output_graph_def);

  // The default values of attributes might have been stripped by the optimizer.
  // Add them back.
  return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
                                   0, true);
}
| |
| std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( |
| const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) { |
| if (id.empty()) { |
| LOG(ERROR) << "id must be non-empty."; |
| return nullptr; |
| } |
| std::unique_ptr<GrapplerItem> new_item(new GrapplerItem()); |
| new_item->id = id; |
| new_item->graph = meta_graph.graph_def(); |
| |
| // Fill in feed nodes from config, if any provided. |
| for (const auto& feed_node : cfg.feed_nodes) { |
| const string feed_name = NodeName(feed_node); |
| new_item->feed.emplace_back(feed_name, Tensor()); |
| } |
| for (const auto& fetch_node : cfg.fetch_nodes) { |
| new_item->fetch.emplace_back(NodeName(fetch_node)); |
| } |
| |
| // Attempt to detect the fetch node(s) if they were not set explicitly. |
| if (new_item->fetch.empty() && |
| meta_graph.collection_def().count("train_op") > 0) { |
| const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); |
| if (nodes.has_node_list()) { |
| for (const auto& node : nodes.node_list().value()) { |
| new_item->fetch.push_back(NodeName(node)); |
| } |
| } |
| } |
| |
| // Detect feed and fetch nodes from signature defs. Signatures may share same |
| // inputs or outputs. |
| std::unordered_set<string> signature_feed_nodes; |
| std::unordered_set<string> signature_fetch_nodes; |
| for (const auto& name_and_signature : meta_graph.signature_def()) { |
| for (const auto& name_and_input : name_and_signature.second.inputs()) { |
| const TensorInfo& input = name_and_input.second; |
| if (input.has_coo_sparse()) { |
| // Define the shapes following the comment of CooSparse. |
| // TODO(yuefengz): we probably want to use different dim values for the |
| // three tensors of a SparseTensor. |
| int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim); |
| TensorShape shape_1d({dim}); |
| TensorShape shape_2d({dim, dim}); |
| |
| if (gtl::InsertIfNotPresent( |
| &signature_feed_nodes, |
| NodeName(input.coo_sparse().values_tensor_name()))) { |
| Tensor value_tensor(input.dtype(), shape_1d); |
| InitializeTensor(input.dtype(), &value_tensor); |
| new_item->feed.emplace_back( |
| NodeName(input.coo_sparse().values_tensor_name()), value_tensor); |
| } |
| if (gtl::InsertIfNotPresent( |
| &signature_feed_nodes, |
| NodeName(input.coo_sparse().indices_tensor_name()))) { |
| Tensor indices_tensor(DT_INT64, shape_2d); |
| InitializeTensor(input.dtype(), &indices_tensor); |
| new_item->feed.emplace_back( |
| NodeName(input.coo_sparse().indices_tensor_name()), |
| indices_tensor); |
| } |
| if (gtl::InsertIfNotPresent( |
| &signature_feed_nodes, |
| NodeName(input.coo_sparse().dense_shape_tensor_name()))) { |
| Tensor dense_shape_tensor(DT_INT64, shape_1d); |
| InitializeTensor(input.dtype(), &dense_shape_tensor); |
| new_item->feed.emplace_back( |
| NodeName(input.coo_sparse().dense_shape_tensor_name()), |
| dense_shape_tensor); |
| } |
| } else { |
| if (gtl::InsertIfNotPresent(&signature_feed_nodes, |
| NodeName(input.name()))) { |
| TensorShape shape; |
| TensorShapeProto shape_proto; |
| Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(), |
| &shape_proto, &shape); |
| if (!s.ok()) { |
| LOG(ERROR) << "Invalid shape for signature input " << input.name() |
| << ": " << s << ", skipping this input"; |
| return nullptr; |
| } |
| |
| Tensor fake_input(input.dtype(), shape); |
| InitializeTensor(input.dtype(), &fake_input); |
| new_item->feed.emplace_back(NodeName(input.name()), fake_input); |
| } |
| } |
| } |
| for (const auto& name_and_output : name_and_signature.second.outputs()) { |
| const TensorInfo& output = name_and_output.second; |
| if (output.has_coo_sparse()) { |
| if (gtl::InsertIfNotPresent( |
| &signature_fetch_nodes, |
| NodeName(output.coo_sparse().values_tensor_name()))) { |
| new_item->fetch.push_back( |
| NodeName(output.coo_sparse().values_tensor_name())); |
| } |
| if (gtl::InsertIfNotPresent( |
| &signature_fetch_nodes, |
| NodeName(output.coo_sparse().indices_tensor_name()))) { |
| new_item->fetch.push_back( |
| NodeName(output.coo_sparse().indices_tensor_name())); |
| } |
| if (gtl::InsertIfNotPresent( |
| &signature_fetch_nodes, |
| NodeName(output.coo_sparse().dense_shape_tensor_name()))) { |
| new_item->fetch.push_back( |
| NodeName(output.coo_sparse().dense_shape_tensor_name())); |
| } |
| } else { |
| if (gtl::InsertIfNotPresent(&signature_fetch_nodes, |
| NodeName(output.name()))) { |
| new_item->fetch.push_back(NodeName(output.name())); |
| } |
| } |
| } |
| } |
| |
| for (const auto& feed : new_item->feed) { |
| if (feed.first.empty()) { |
| LOG(ERROR) << "Invalid feed node name skipping this input"; |
| return nullptr; |
| } else { |
| VLOG(1) << "Will use feed node " << feed.first; |
| } |
| } |
| |
| for (const auto& fetch : new_item->fetch) { |
| if (fetch.empty()) { |
| LOG(ERROR) << "Invalid fetch node name skipping this input"; |
| return nullptr; |
| } else { |
| VLOG(1) << "Will use fetch node " << fetch; |
| } |
| } |
| |
| if (new_item->fetch.empty()) { |
| LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input"; |
| return nullptr; |
| } |
| |
| // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op. |
| // The reason why they are difficult to handle is because they may not intend |
| // to initialize all variables that are required to run fetch nodes. We may |
| // have to run restore op first. |
| |
| // Try to find initializers from variables and tables as init ops. |
| for (const string& var_collection : |
| {"variables", "local_variables", "model_variables", |
| "trainable_variables"}) { |
| if (meta_graph.collection_def().count(var_collection) == 0) { |
| continue; |
| } |
| const CollectionDef& vars = meta_graph.collection_def().at(var_collection); |
| for (const auto& raw_var : vars.bytes_list().value()) { |
| VariableDef var; |
| var.ParseFromString(raw_var); |
| if (!var.initializer_name().empty()) { |
| new_item->init_ops.push_back(NodeName(var.initializer_name())); |
| } |
| } |
| } |
| |
| if (meta_graph.collection_def().count("table_initializer") > 0) { |
| const CollectionDef& inits = |
| meta_graph.collection_def().at("table_initializer"); |
| if (inits.has_node_list()) { |
| for (const auto& node : inits.node_list().value()) { |
| new_item->init_ops.push_back(NodeName(node)); |
| // Tables are initialized from files, which can take a long time. Add |
| // 30 minutes to the initialization time for each table to avoid |
| // timing out. |
| // TODO(bsteiner): adjust the timeout based on the file size. |
| new_item->expected_init_time += 30 * 60; |
| } |
| } |
| } |
| |
| // We keep the mapping from asset node to asset files. This should have been |
| // used as feed but since asset node is usually a constant node, we will fill |
| // the values of these constant nodes with their actual asset file paths. |
| std::unordered_map<string, string> asset_node_to_value; |
| |
| // Assets file may have changed their directory, we assemble their new paths |
| // if assets_directory_override is set. We also make sure we still can |
| // access these asset files. |
| if (!cfg.assets_directory_override.empty()) { |
| if (meta_graph.collection_def().count("saved_model_assets") > 0) { |
| const CollectionDef& collection = |
| meta_graph.collection_def().at("saved_model_assets"); |
| const auto& any_assets = collection.any_list().value(); |
| for (const auto& any_asset : any_assets) { |
| AssetFileDef asset_file_def; |
| if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef") |
| .ok()) { |
| LOG(ERROR) << "Failed to parse AssetFile."; |
| continue; |
| } |
| string asset_filepath = io::JoinPath(cfg.assets_directory_override, |
| asset_file_def.filename()); |
| if (!FilesExist({asset_filepath}, nullptr)) { |
| LOG(ERROR) << "Can't access one or more of the asset files " |
| << asset_filepath << ", skipping this input"; |
| return nullptr; |
| } |
| asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] = |
| asset_filepath; |
| } |
| } |
| } else if (meta_graph.collection_def().count("asset_filepaths") > 0) { |
| const CollectionDef& file_paths = |
| meta_graph.collection_def().at("asset_filepaths"); |
| std::vector<string> paths; |
| for (const auto& raw_path : file_paths.bytes_list().value()) { |
| paths.push_back(raw_path); |
| } |
| if (!FilesExist(paths, nullptr)) { |
| LOG(ERROR) << "Can't access one or more of the asset files, skipping " |
| "this input"; |
| return nullptr; |
| } |
| } |
| |
| if (meta_graph.collection_def().count("queue_runners") > 0) { |
| const CollectionDef& vars = meta_graph.collection_def().at("queue_runners"); |
| for (const auto& raw : vars.bytes_list().value()) { |
| QueueRunnerDef queue_runner; |
| if (!queue_runner.ParseFromString(raw)) { |
| LOG(ERROR) << "Could not parse queue_runners, skipping this input"; |
| return nullptr; |
| } |
| if (queue_runner.cancel_op_name().empty()) { |
| LOG(ERROR) << "Queue without a cancel op, skipping this input"; |
| return nullptr; |
| } |
| new_item->queue_runners.push_back(queue_runner); |
| } |
| } |
| |
| // Add each node referenced in a collection to the list of nodes to keep. |
| for (const auto& col : meta_graph.collection_def()) { |
| const CollectionDef& collection = col.second; |
| for (const string& node : collection.node_list().value()) { |
| new_item->keep_ops.push_back(NodeName(node)); |
| } |
| } |
| |
| for (auto& node : *new_item->graph.mutable_node()) { |
| if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") { |
| Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes, |
| new_item.get(), &node); |
| if (!s.ok()) return nullptr; |
| } else if (IsConstant(node)) { |
| auto it = asset_node_to_value.find(node.name()); |
| if (it != asset_node_to_value.end()) { |
| auto iter = node.mutable_attr()->find("value"); |
| if (iter == node.attr().end()) { |
| LOG(ERROR) << "Value attribute expected in const op for asset files"; |
| return nullptr; |
| } |
| if (!iter->second.has_tensor() || |
| iter->second.tensor().string_val_size() != 1) { |
| LOG(INFO) << "Unexpected AttrValue proto: " |
| << iter->second.DebugString(); |
| return nullptr; |
| } |
| LOG(INFO) << "Using asset file " << it->second << " for node " |
| << node.name(); |
| *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second; |
| } |
| } |
| |
| // Erase the recorded result of any previous shape inference to start again |
| // from scratch. |
| node.mutable_attr()->erase("_output_shapes"); |
| |
| // Delete user specified placement if requested. |
| if (cfg.ignore_user_placement) { |
| node.clear_device(); |
| } |
| // Delete colocation constraints if requested. |
| if (cfg.ignore_colocation) { |
| auto attr = node.mutable_attr(); |
| auto it = attr->find("_class"); |
| if (it != attr->end()) { |
| attr->erase(it); |
| } |
| } |
| } |
| |
| if (meta_graph.collection_def().count("savers") > 0) { |
| const CollectionDef& savers = meta_graph.collection_def().at("savers"); |
| for (const auto& raw : savers.bytes_list().value()) { |
| SaverDef saver; |
| // Skip bad savers since we don't need saves/restores to be able to run a |
| // graph. |
| if (!saver.ParseFromString(raw)) { |
| continue; |
| } |
| if (saver.filename_tensor_name().empty()) { |
| continue; |
| } |
| new_item->save_op = saver.save_tensor_name(); |
| new_item->restore_op = saver.restore_op_name(); |
| new_item->save_restore_loc_tensor = saver.filename_tensor_name(); |
| // Only use the first saver since it's not clear what to do if there's |
| // more than one. |
| break; |
| } |
| } else { |
| const SaverDef& saver = meta_graph.saver_def(); |
| new_item->save_op = saver.save_tensor_name(); |
| new_item->restore_op = saver.restore_op_name(); |
| new_item->save_restore_loc_tensor = saver.filename_tensor_name(); |
| } |
| |
| // Instantiate all the missing attributes with their default values. |
| Status attr_status = AddDefaultAttrsToGraphDef( |
| &new_item->graph, |
| FunctionLibraryDefinition(OpRegistry::Global(), |
| new_item->graph.library()), |
| 0, true); |
| if (!attr_status.ok()) { |
| LOG(ERROR) << "Failed to instantiate default attribute values: " |
| << attr_status.error_message(); |
| return nullptr; |
| } |
| |
| // Optimize the graph (function inlining, l1 optimizations, etc). |
| VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: " |
| << new_item->graph.node_size(); |
| Status optimize_status = |
| RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg); |
| if (!optimize_status.ok()) { |
| LOG(ERROR) << "Graph preprocessing failed: " << optimize_status; |
| return nullptr; |
| } |
| VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: " |
| << new_item->graph.node_size(); |
| |
| if (cfg.prune_graph) { |
| VLOG(1) << "Pruning graph..."; |
| auto status = PruneGraph(new_item.get()); |
| if (!status.ok()) { |
| LOG(ERROR) << "Pruning failed: " << status.error_message(); |
| return nullptr; |
| } |
| VLOG(1) << "Number of nodes in graph after pruning: " |
| << new_item->graph.node_size(); |
| } |
| |
| // Validate feed, fetch and init nodes |
| std::unordered_set<string> nodes; |
| for (const auto& node : new_item->graph.node()) { |
| nodes.insert(node.name()); |
| } |
| for (const auto& feed : new_item->feed) { |
| if (nodes.find(feed.first) == nodes.end()) { |
| LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph"; |
| return nullptr; |
| } |
| } |
| for (const auto& fetch : new_item->fetch) { |
| if (nodes.find(fetch) == nodes.end()) { |
| LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph"; |
| return nullptr; |
| } |
| } |
| for (const auto& init : new_item->init_ops) { |
| if (nodes.find(init) == nodes.end()) { |
| LOG(ERROR) << "Init node " << init << " doesn't exist in graph"; |
| return nullptr; |
| } |
| } |
| return new_item; |
| } |
| |
| std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile( |
| const string& id, const string& meta_graph_file, const ItemConfig& cfg) { |
| MetaGraphDef meta_graph; |
| if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) { |
| LOG(ERROR) << "Failed to read " << meta_graph_file; |
| return nullptr; |
| } |
| return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg); |
| } |
| |
| } // end namespace grappler |
| } // end namespace tensorflow |