| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" |
| |
| #include <algorithm> |
| #include <deque> |
| #include <stack> |
| #include <unordered_set> |
| #include <vector> |
| |
| #include "absl/memory/memory.h" |
| #include "absl/types/optional.h" |
| #include "tensorflow/compiler/jit/union_find.h" |
| #include "tensorflow/compiler/tf2xla/functionalize_cond.h" |
| #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" |
| #include "tensorflow/compiler/tf2xla/functionalize_while.h" |
| #include "tensorflow/compiler/tf2xla/tf2xla_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/graph_optimizer.h" |
| #include "tensorflow/core/common_runtime/process_function_library_runtime.h" |
| #include "tensorflow/core/framework/graph_to_functiondef.h" |
| #include "tensorflow/core/framework/node_def_builder.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/control_flow.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/graph/node_builder.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/public/session_options.h" |
| #include "tensorflow/core/public/version.h" |
| #include "tensorflow/core/util/dump_graph.h" |
| |
| namespace tensorflow { |
| |
| // Transformation that converts TensorFlow's graph control flow constructs into |
| // functional equivalents. |
| Status FunctionalizeControlFlow(Graph* graph, |
| FunctionLibraryDefinition* library) { |
| VLOG(2) << "FunctionalizeControlFlow (initial): " |
| << DumpGraphToFile("functionalize_initial", *graph, library); |
| |
| // Functionalize and remove while loops from graph. |
| TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library)); |
| |
| // FunctionalizeControlFlow is invoked for every function, so the loops's |
| // bodies and conditionals that were extracted into functions will be handled |
| // in successive invocations. |
| TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library)); |
| |
| VLOG(2) << "FunctionalizeControlFlow (final): " |
| << DumpGraphToFile("functionalize_final", *graph, library); |
| |
| return Status::OK(); |
| } |
| |
| Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, |
| FunctionLibraryDefinition* library) { |
| FunctionDefLibrary function_lib = graph_def->library(); |
| Graph graph(OpRegistry::Global()); |
| |
| TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); |
| TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library)); |
| graph.ToGraphDef(graph_def); |
| std::swap(*graph_def->mutable_library(), function_lib); |
| return Status::OK(); |
| } |
| |
| Status FunctionalizeControlFlowForFunction( |
| const string& func_name, const string& new_func_name, |
| const protobuf::Map<string, tensorflow::AttrValue>& attrs, |
| FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, |
| std::map<string, absl::optional<string>>* canonicalized_name_to_new_name, |
| bool* modified) { |
| *modified = false; |
| |
| // Convert the function to Graph. |
| FunctionLibraryRuntime::Handle handle; |
| TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); |
| Status ret_status = Status::OK(); |
| auto cleanup_handle = gtl::MakeCleanup([&]() { |
| auto s = flr->ReleaseHandle(handle); |
| if (!s.ok()) { |
| ret_status.Update(s); |
| } |
| }); |
| const FunctionBody* body = flr->GetFunctionBody(handle); |
| Graph* g = body->graph; |
| |
| // Check if the graph has Switch or Merge node. |
| bool has_switch_or_merge = false; |
| for (Node* n : body->graph->nodes()) { |
| if (n->type_string() == "Switch" || n->type_string() == "Merge") { |
| has_switch_or_merge = true; |
| break; |
| } |
| } |
| // We cannot return here directly if the graph has no Switch/Merge. |
| // It might contain function call nodes, or If/While nodes with Switch/Merge |
| // in function body. We still need to rewrite those functions and modify |
| // corresponding nodes. |
| |
| // If any node has associated functions, functionalize them first. |
| // Gather nodes with associated functions first, because rewriting those nodes |
| // might involve node deletion/addition. Avoid modifying nodes while iterating |
| // it. |
| std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>> |
| nodes_to_associated_functions; |
| for (auto* n : g->nodes()) { |
| auto associated_functions = GetAssociatedFunctions(*n, fld); |
| if (!associated_functions.empty()) { |
| nodes_to_associated_functions.push_back({n, associated_functions}); |
| } |
| } |
| for (auto iter : nodes_to_associated_functions) { |
| Node* n = iter.first; |
| auto associated_functions = iter.second; |
| for (auto& associated_function : associated_functions) { |
| string name = associated_function.func_name(); |
| string canonicalized_name = |
| Canonicalize(name, AttrSlice(&associated_function.attrs())); |
| auto iter = canonicalized_name_to_new_name->find(canonicalized_name); |
| string new_name; |
| bool function_modified; |
| if (iter != canonicalized_name_to_new_name->end()) { |
| // If we already processed this function, check if it was rewritten. If |
| // the function was rewritten, the entry will be non-empty. Otherwise |
| // the entry will be empty. |
| function_modified = iter->second.has_value(); |
| if (function_modified) { |
| new_name = iter->second.value(); |
| } |
| } else { |
| if (associated_function.type() == |
| AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { |
| // For SymbolicGradient, `name` is always "SymbolicGradient", |
| // which is not very informative. Use node name instead. |
| new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")); |
| } else { |
| new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); |
| } |
| TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( |
| name, new_name, associated_function.attrs(), fld, flr, |
| canonicalized_name_to_new_name, &function_modified)); |
| if (function_modified) { |
| // If the function was rewritten, add an non-empty entry. So later we |
| // know we have processed this function, and it was rewritten into |
| // another function. |
| (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; |
| } else { |
| // If the function was not rewritten, add an empty entry. So later |
| // we know we have processed this function, and it does not need to be |
| // rewritten. |
| (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; |
| } |
| } |
| if (function_modified) { |
| *modified = true; |
| |
| // Notice that if "n" is a function call, RewriteAssociatedFunction() |
| // will delete it and create a new node instead, making "n" an invalid |
| // pointer. That's fine because in that case, associated_functions will |
| // only have one member and the loop will only run once. |
| TF_RETURN_IF_ERROR(RewriteAssociatedFunction( |
| g, n, fld, associated_function, new_name)); |
| } |
| } |
| } |
| |
| if (has_switch_or_merge) { |
| *modified = true; |
| |
| // Functionalize the function body. |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile( |
| absl::StrCat("functionalize_control_flow_before_fdef_", func_name), |
| *g, fld); |
| } |
| TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile( |
| absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, |
| fld); |
| } |
| } |
| |
| if (*modified) { |
| // Add rewritten FunctionDef into library. |
| FunctionDef functionalized_fdef; |
| TF_RETURN_IF_ERROR( |
| GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); |
| if (func_name == new_func_name) { |
| VLOG(2) << "Replacing function " << func_name; |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(new_func_name, functionalized_fdef)); |
| } else { |
| VLOG(2) << "Adding function " << new_func_name; |
| TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); |
| } |
| } |
| |
| return ret_status; |
| } |
| |
| Status FunctionalizeControlFlowForXlaPass::Run( |
| const GraphOptimizationPassOptions& options) { |
| Graph* graph = options.graph->get(); |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile("functionalize_control_flow_before", *graph, |
| options.flib_def); |
| } |
| std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( |
| new ProcessFunctionLibraryRuntime( |
| /*device_mgr=*/nullptr, options.session_options->env, |
| TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); |
| FunctionLibraryRuntime* flr = |
| pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); |
| |
| // Find XLA compile ops and its corresponding FunctionDef. |
| // TPUCompile op is not in the map because graph rewriting might happen |
| // multiple times, and we want to avoid functionalize it again. |
| static std::map<string, string>* kNodeTypeToFunctionAttrMapping = |
| new std::map<string, string>{ |
| // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass. |
| {"_TPUReplicate", "computation"}, |
| // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. |
| {"XlaLaunch", "function"}, |
| }; |
| std::map<string, absl::optional<string>> canonicalized_name_to_new_name; |
| bool fld_modified = false; |
| for (Node* n : graph->nodes()) { |
| auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); |
| if (it == kNodeTypeToFunctionAttrMapping->end()) { |
| continue; |
| } |
| const string func_attr = it->second; |
| NameAttrList func; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); |
| VLOG(2) << "Graph has node " << n->type_string() |
| << ". Corresponding function: " << func.name(); |
| string new_func_name = options.flib_def->UniqueFunctionName( |
| absl::StrCat(func.name(), "_f15n_")); |
| bool modified; |
| TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( |
| func.name(), new_func_name, func.attr(), options.flib_def, flr, |
| &canonicalized_name_to_new_name, &modified)); |
| if (modified) { |
| n->ClearAttr(func_attr); |
| func.set_name(new_func_name); |
| n->AddAttr(func_attr, func); |
| fld_modified = true; |
| } |
| } |
| |
| // TODO(ylc, endlessroad): Change this to "if (fld_modified")" |
| if (false) { |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile("functionalize_control_flow_before_prune", *graph, |
| options.flib_def); |
| } |
| TF_RETURN_IF_ERROR( |
| PruneUnreachableFunctionsFromGraph(*graph, options.flib_def)); |
| } |
| |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile("functionalize_control_flow_after", *graph, |
| options.flib_def); |
| } |
| return Status::OK(); |
| } |
| |
| Status FunctionalizeControlFlowPass::Run( |
| const GraphOptimizationPassOptions& options) { |
| return FunctionalizeControlFlow(options.graph->get(), options.flib_def); |
| } |
| |
| } // namespace tensorflow |