| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/common_runtime/control_flow_deps_to_chains.h" |
| |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op_def_builder.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/strcat.h" |
| #include "tensorflow/core/util/dump_graph.h" |
| |
| namespace tensorflow { |
| |
| // TODO(mdan): Move this into Grappler - cleaner interface. |
| Status ControlFlowDepsToChainsPass::Run( |
| const GraphOptimizationPassOptions& options) { |
| VLOG(1) << "ControlFlowDepsToChainsPass::Run"; |
| |
| if (options.graph == nullptr) { |
| VLOG(1) << "ControlFlowDepsToChainsPass::Run Aborted"; |
| return Status::OK(); |
| } |
| |
| Graph* g = options.graph->get(); |
| DCHECK(g != nullptr); |
| FunctionLibraryDefinition* flib_def = options.flib_def; |
| DCHECK(flib_def != nullptr); |
| |
| if (VLOG_IS_ON(1)) { |
| DumpGraphToFile("control_flow_deps_to_chains_before", *g, flib_def); |
| } |
| |
| for (Node* n : g->nodes()) { |
| if (n == nullptr) continue; |
| if (!n->IsWhileNode()) continue; |
| |
| // TODO(mdan): This breaks encapsulation of Node/Graph. Is there any needed? |
| // TODO(mdan): Consolidate this with AddWhileInputHack. |
| NodeDef* while_node = n->mutable_def(); |
| const auto& attrs = while_node->attr(); |
| auto* mattrs = while_node->mutable_attr(); |
| |
| string body_name = attrs.at("body").func().name(); |
| auto* body_graph = flib_def->Find(body_name); |
| DCHECK(body_graph != nullptr); |
| |
| // Look for required annotations. |
| |
| if (attrs.find("_stateful_parallelism") == attrs.end()) continue; |
| if (!attrs.at("_stateful_parallelism").b()) continue; |
| // TODO(mdan): We don't really need this attribute. |
| if (attrs.find("_num_original_outputs") == attrs.end()) continue; |
| int body_barrier_loc = -1; |
| std::map<string, int> node_index; |
| for (int i = 0, s = body_graph->node_def_size(); i < s; i++) { |
| node_index.emplace(body_graph->node_def(i).name(), i); |
| if (body_barrier_loc < 0) { |
| const auto& node_attr = body_graph->node_def(i).attr(); |
| if (node_attr.find("_acd_function_control_output") != node_attr.end()) { |
| body_barrier_loc = i; |
| } |
| } |
| } |
| if (body_barrier_loc < 0) continue; |
| bool ok_for_lowering = true; |
| for (int i = 0; i < body_graph->control_ret_size(); i++) { |
| const auto& control_node = body_graph->node_def( |
| node_index[body_graph->signature().control_output(i)]); |
| const auto& control_attr = control_node.attr(); |
| if (control_attr.find("_res_first_used_by") == control_attr.end()) { |
| ok_for_lowering = false; |
| break; |
| } |
| } |
| if (!ok_for_lowering) continue; |
| |
| int num_loop_vars = body_graph->signature().input_arg_size(); |
| int num_new_chains = body_graph->control_ret_size(); |
| int num_node_inputs = while_node->input_size(); |
| |
| if (!num_new_chains) continue; // Nothing to do for stateless loops. |
| |
| // Add extra loop vars to the while node. |
| |
| // TODO(mdan): If the loop vars contains the resource, we should reuse it. |
| // Note that stateful ops of resource inputs cause their resources to be |
| // captured into the loop vars (through the body/cond captures). We could |
| // effectively use those as chains. |
| |
| // TODO(mdan): Is there a more efficient way to do this? |
| // Insert the new While node inputs: at the end of the loop vars, but before |
| // any non-loop var inputs (like control dependencies). Once the initial |
| // chain values are created below, they will be added to these inputs. |
| for (int i = 0; i < num_new_chains; i++) { |
| while_node->add_input(); |
| } |
| for (int i = num_node_inputs - 1; i >= num_loop_vars; i--) { |
| while_node->set_input(i + num_new_chains, while_node->input(i)); |
| } |
| |
| std::vector<Node*> new_inputs; |
| std::vector<int> new_input_locations; |
| // Set their name to a gensym, type to float and shape to scalar. |
| for (int i = 0; i < num_new_chains; i++) { |
| string c_name = g->NewName("acd__chain"); |
| |
| // The initial value for the i'th chain loop var. |
| NodeDef new_in; |
| new_in.set_name(c_name); |
| new_in.set_op("Const"); |
| AttrValue att_dtype; |
| att_dtype.set_type(DT_FLOAT); |
| new_in.mutable_attr()->insert({"dtype", att_dtype}); |
| AttrValue att_value; |
| att_value.mutable_tensor()->set_dtype(DT_FLOAT); |
| att_value.mutable_tensor()->mutable_tensor_shape(); |
| att_value.mutable_tensor()->add_int_val(0); |
| new_in.mutable_attr()->insert({"value", att_value}); |
| Status status; |
| new_inputs.push_back(g->AddNode(new_in, &status)); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(status, "while creating chain", c_name); |
| |
| int loc = num_loop_vars + i; |
| new_input_locations.push_back(loc); |
| while_node->set_input(loc, c_name); |
| mattrs->at("T").mutable_list()->add_type(DT_FLOAT); |
| mattrs->at("output_shapes").mutable_list()->add_shape(); |
| } |
| |
| // TODO(mdan): This should not be necessary to update. Delete? |
| mattrs->at("_num_original_outputs").set_i(num_loop_vars + num_new_chains); |
| n->UpdateProperties(); |
| for (int i = 0; i < num_new_chains; i++) { |
| g->AddEdge(new_inputs[i], 0, n, new_input_locations[i]); |
| } |
| |
| // TODO(mdan): This is wasteful. Can we just mutate the original proto? |
| FunctionDef modified_body = *body_graph; |
| |
| // Disable the global end-of-body barrier from the body function. |
| // Because removing a node is too inefficient (would have to walk all the |
| // inputs of all graph nodes), we instead clear its control dependencies. |
| modified_body.mutable_node_def(body_barrier_loc)->clear_input(); |
| |
| // Add extra loop vars to the body function. |
| |
| for (int i = 0; i < num_new_chains; i++) { |
| // Input loop vars. |
| // TODO(mdan): Double check that this doesn't clash with names in body. |
| string c_name = g->NewName("acd__chain"); |
| std::replace(c_name.begin(), c_name.end(), '/', '_'); |
| auto* new_arg = modified_body.mutable_signature()->add_input_arg(); |
| new_arg->set_name(c_name); |
| new_arg->set_type(DT_FLOAT); |
| |
| // Output ops. These are copies of the inputs conditioned on the actual |
| // control outputs. |
| string c_out_name = g->NewName("acd__outchain"); |
| auto* new_out = modified_body.add_node_def(); |
| new_out->set_name(c_out_name); |
| new_out->set_op("Identity"); |
| new_out->add_input(c_name); |
| new_out->add_input( |
| strings::StrCat("^", body_graph->signature().control_output(i))); |
| AttrValue attr; |
| attr.set_type(DT_FLOAT); |
| new_out->mutable_attr()->insert({"T", attr}); |
| |
| // Output loop var declarations. |
| string c_ret_name = c_out_name; |
| std::replace(c_ret_name.begin(), c_ret_name.end(), '/', '_'); |
| auto* new_out_arg = modified_body.mutable_signature()->add_output_arg(); |
| new_out_arg->set_name(c_ret_name); |
| new_out_arg->set_type(DT_FLOAT); |
| |
| // Actual output loop vars. |
| modified_body.mutable_ret()->insert( |
| {c_ret_name, strings::StrCat(c_out_name, ":output:0")}); |
| AttrValue attr_val; |
| attr_val.mutable_list()->mutable_shape(); |
| FunctionDef_ArgAttrs arg_attrs; |
| arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val}); |
| modified_body.mutable_arg_attr()->insert({i + num_loop_vars, arg_attrs}); |
| } |
| |
| // Wire chain loop vars to the ops they need to condition. |
| |
| node_index.clear(); |
| for (int i = 0; i < modified_body.node_def_size(); i++) { |
| node_index.emplace(modified_body.node_def(i).name(), i); |
| } |
| auto& modified_sig = modified_body.signature(); |
| for (int i = 0; i < num_new_chains; i++) { |
| const auto& control_node = |
| modified_body.node_def(node_index[modified_sig.control_output(i)]); |
| for (const auto& r : |
| control_node.attr().at("_res_first_used_by").list().s()) { |
| NodeDef* first_node = modified_body.mutable_node_def(node_index[r]); |
| // This control dependency ensures proper sequencing of stateful ops |
| // upon entry into the loop body, so that they run after the ops |
| // which affected the same resource in the previous iteration. |
| first_node->add_input(strings::StrCat( |
| "^", modified_sig.input_arg(i + num_loop_vars).name())); |
| } |
| } |
| |
| // Clear body function's control returns. |
| modified_body.mutable_control_ret()->clear(); |
| |
| // Add extra loop vars to the cond function. |
| |
| // TODO(mdan): This is wasteful. Can't we just mutate the original proto? |
| string cond_name = attrs.at("cond").func().name(); |
| auto* cond_graph = flib_def->Find(cond_name); |
| DCHECK(cond_graph != nullptr); |
| FunctionDef modified_cond = *cond_graph; |
| |
| int cond_barrier_loc = -1; |
| for (int i = 0, s = cond_graph->node_def_size(); i < s; i++) { |
| if (cond_barrier_loc < 0) { |
| const auto& node_attr = cond_graph->node_def(i).attr(); |
| if (node_attr.find("_acd_function_control_output") != node_attr.end()) { |
| cond_barrier_loc = i; |
| } |
| } |
| } |
| if (cond_barrier_loc > 0) { |
| // Disable the global end-of-body barrier from the cond function. |
| // Because removing a node is too inefficient (would have to walk all the |
| // inputs of all graph nodes), we instead clear its control dependencies. |
| modified_cond.mutable_node_def(cond_barrier_loc)->clear_input(); |
| } |
| |
| for (int i = 0; i < num_new_chains; i++) { |
| // Input loop vars. |
| // TODO(mdan): These should gate the stateful ops in the cond. |
| // Until ACD supplies the necessary information, these are dummies in this |
| // function. |
| string c_name = g->NewName("acd__chain"); |
| auto* new_arg = modified_cond.mutable_signature()->add_input_arg(); |
| new_arg->set_name(c_name); |
| new_arg->set_type(DT_FLOAT); |
| |
| // TODO(mdan): Return values on the cond function? Most likely a bug. |
| AttrValue attr_val; |
| attr_val.mutable_list()->mutable_shape(); |
| FunctionDef_ArgAttrs arg_attrs; |
| arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val}); |
| modified_cond.mutable_arg_attr()->insert({i + num_loop_vars, arg_attrs}); |
| } |
| |
| // Wire the new cond/body functions to the While node. |
| |
| string new_cond_name = g->NewName("acd__while_cond"); |
| modified_cond.mutable_signature()->set_name(new_cond_name); |
| mattrs->at("cond").mutable_func()->set_name(new_cond_name); |
| |
| string new_body_name = g->NewName("acd__while_body"); |
| modified_body.mutable_signature()->set_name(new_body_name); |
| mattrs->at("body").mutable_func()->set_name(new_body_name); |
| |
| // Commit the new functions. |
| |
| // TODO(b/183666205): One of these two should not be necessary. |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| flib_def->AddFunctionDef(modified_body, |
| flib_def->GetStackTraces(body_name)), |
| "while attaching", body_name, "to flib_def"); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| flib_def->AddFunctionDef(modified_cond, |
| flib_def->GetStackTraces(cond_name)), |
| "while attaching", cond_name, "to flib_def"); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| g->mutable_flib_def()->AddFunctionDef( |
| modified_body, flib_def->GetStackTraces(body_name)), |
| "while attaching", body_name, "to graph"); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| g->mutable_flib_def()->AddFunctionDef( |
| modified_cond, flib_def->GetStackTraces(cond_name)), |
| "while attaching", cond_name, "to grap"); |
| } |
| |
| if (VLOG_IS_ON(1)) { |
| DumpGraphToFile("control_flow_deps_to_chains_after", *g, flib_def); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Note: This needs to run before functional control flow lowering, which is 10. |
| REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 9, |
| ControlFlowDepsToChainsPass); |
| |
| } // namespace tensorflow |