| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/strings/match.h" |
| #include "absl/strings/str_cat.h" |
| #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" |
| #include "tensorflow/compiler/jit/encapsulate_util.h" |
| #include "tensorflow/compiler/tf2xla/side_effect_util.h" |
| #include "tensorflow/compiler/tf2xla/tf2xla_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/graph_to_functiondef.h" |
| #include "tensorflow/core/framework/node_def_builder.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/platform/macros.h" |
| #include "tensorflow/core/util/dump_graph.h" |
| #include "tensorflow/stream_executor/lib/statusor.h" |
| |
| namespace tensorflow { |
| |
| namespace { |
| |
| // Control return mapping function for outside compilation host graphs. |
| // All nodes with kXlaHasHostTransfer attribute are control outputs. |
| absl::optional<string> HostGraphControlRetMapping(const Node* n) { |
| if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { |
| return n->name(); |
| } |
| return absl::nullopt; |
| } |
| |
| // Add a key placeholder node to the graph. The key placeholder node will be |
| // used as input for XlaRecvAtHost/XlaSendFromHost nodes. |
| xla::StatusOr<Node*> AddHostComputeKeyPlaceholder( |
| const string& xla_cluster_name, Graph* g) { |
| NodeDef key_def; |
| NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"), |
| "Placeholder"); |
| builder.Attr("dtype", DT_STRING); |
| builder.Attr("shape", PartialTensorShape({2})); |
| builder.Attr("_host_compute_call_node", xla_cluster_name); |
| Status s = builder.Finalize(&key_def); |
| if (!s.ok()) return s; |
| |
| Node* n = g->AddNode(key_def, &s); |
| if (!s.ok()) return s; |
| return n; |
| } |
| |
| // Returns if the node is a XLA computation key placeholder. |
| bool IsKeyPlaceholderNode(const Node& n) { |
| return n.type_string() == "Placeholder" && |
| absl::EndsWith(n.name(), "_key_placeholder"); |
| } |
| |
| // Returns nodes with given type. |
| std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) { |
| std::vector<Node*> result; |
| for (Node* n : g.nodes()) { |
| if (n->type_string() == type) { |
| result.push_back(n); |
| } |
| } |
| return result; |
| } |
| |
| // Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`. |
| Status GetArgDataTypes(const std::vector<Node*>& arg_nodes, |
| std::vector<DataType>* recv_at_host_dtypes) { |
| recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID); |
| for (auto* n : arg_nodes) { |
| int index; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); |
| DataType dtype; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype)); |
| (*recv_at_host_dtypes)[index] = dtype; |
| } |
| for (int i = 0; i < recv_at_host_dtypes->size(); i++) { |
| if ((*recv_at_host_dtypes)[i] == DT_INVALID) { |
| return errors::Internal("Cannot get datatype for input ", i); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // Builds XlaRecvAtHost node. |
| xla::StatusOr<Node*> BuildRecvAtHostNode( |
| Graph* g, const string& oc_cluster_name, |
| const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) { |
| NodeDefBuilder recv_at_host_builder( |
| absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"), |
| "_XlaRecvAtHost"); |
| NodeDef recv_at_host_def; |
| recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes); |
| // The correct device_ordinal will be inserted during replication in a |
| // subsequent rewrite. |
| AttrValue device_ordinal_value; |
| device_ordinal_value.set_placeholder("_device_ordinal"); |
| recv_at_host_builder.Attr("device_ordinal", device_ordinal_value); |
| recv_at_host_builder.Attr( |
| "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); |
| recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true); |
| recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING); |
| TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def)); |
| Status s; |
| Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| return recv_at_host_node; |
| } |
| |
| // Builds XlaRecvAtHost node, and replaces all _Arg nodes with it. |
| xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode( |
| Graph* g, const string& oc_cluster_name, |
| std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) { |
| // TODO(b/77601805): use out nodes for source node, instead of traversing all |
| // nodes. |
| std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg"); |
| TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes)); |
| TF_ASSIGN_OR_RETURN( |
| Node * recv_at_host_node, |
| BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes, |
| key_placeholder)); |
| for (auto* n : arg_nodes) { |
| int index; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); |
| // Record out edges and remove `n` before adding those edges to RecvAtHost. |
| // This is to avoid multiple producers. |
| std::vector<OutEdgeInfo> out_edge_info; |
| for (auto edge : n->out_edges()) { |
| out_edge_info.push_back( |
| {edge->dst(), edge->src_output(), edge->dst_input()}); |
| } |
| g->RemoveNode(n); |
| for (const OutEdgeInfo& edge : out_edge_info) { |
| if (edge.dst_input == Graph::kControlSlot) { |
| g->AddControlEdge(recv_at_host_node, edge.dst); |
| } else { |
| g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input); |
| } |
| } |
| |
| // Rewrite dst nodes because their input changed. |
| for (int i = 0; i < out_edge_info.size(); i++) { |
| const OutEdgeInfo edge = out_edge_info[i]; |
| if (edge.dst_input == Graph::kControlSlot) { |
| continue; |
| } |
| |
| Node* dst = edge.dst; |
| NodeDef new_def = dst->def(); |
| *new_def.mutable_input(edge.dst_input) = |
| absl::StrCat(recv_at_host_node->name(), ":", index); |
| TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def)); |
| |
| // Other edges might have `dst` as dst node as well. Update those edges |
| // with `dst_replace`. |
| for (int j = i + 1; j < out_edge_info.size(); j++) { |
| if (out_edge_info[j].dst == dst) { |
| out_edge_info[j].dst = dst_replace; |
| } |
| } |
| } |
| } |
| g->AddEdge(key_placeholder, 0, recv_at_host_node, 0); |
| return recv_at_host_node; |
| } |
| |
| // Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`. |
| Status GetRetDataTypes(const std::vector<Node*>& ret_nodes, |
| std::vector<DataType>* send_from_host_dtypes) { |
| send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID); |
| for (auto* n : ret_nodes) { |
| int index; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); |
| DataType dtype; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype)); |
| (*send_from_host_dtypes)[index] = dtype; |
| } |
| for (int i = 0; i < send_from_host_dtypes->size(); i++) { |
| if ((*send_from_host_dtypes)[i] == DT_INVALID) { |
| return errors::Internal("Cannot get datatype for output ", i); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // Builds XlaSendFromHost node. |
| xla::StatusOr<Node*> BuildSendFromHostNode( |
| Graph* g, const string& oc_cluster_name, |
| const std::vector<Node*>& ret_nodes, |
| const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) { |
| NodeDefBuilder send_from_host_builder( |
| absl::StrCat("outside_compilation_", oc_cluster_name, "_send"), |
| "_XlaSendFromHost"); |
| NodeDef send_from_host_def; |
| send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); |
| // The correct device_ordinal will be inserted during replication in a |
| // subsequent rewrite. |
| AttrValue device_ordinal_value; |
| device_ordinal_value.set_placeholder("_device_ordinal"); |
| send_from_host_builder.Attr("device_ordinal", device_ordinal_value); |
| send_from_host_builder.Attr( |
| "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); |
| send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true); |
| std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size()); |
| for (auto* n : ret_nodes) { |
| int index; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); |
| if (index < 0 || index >= send_from_host_dtypes.size()) { |
| return errors::Internal("Invalid _Retval index: ", index); |
| } |
| for (auto edge : n->in_edges()) { |
| inputs[index] = |
| NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(), |
| edge->src()->output_type(edge->src_output())}; |
| } |
| } |
| send_from_host_builder.Input(inputs); |
| send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING); |
| TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def)); |
| Status s; |
| Node* send_from_host_node = g->AddNode(send_from_host_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| return send_from_host_node; |
| } |
| |
| // Builds XlaSendFromHost node, and replaces all _Retval nodes with it. |
| xla::StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode( |
| Graph* g, const string& oc_cluster_name, |
| std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) { |
| // TODO(b/77601805): use in nodes for sink node, instead of traversing all |
| // nodes. |
| std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval"); |
| TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes)); |
| TF_ASSIGN_OR_RETURN( |
| Node * send_from_host_node, |
| BuildSendFromHostNode(g, oc_cluster_name, ret_nodes, |
| *send_from_host_dtypes, key_placeholder)); |
| for (auto* n : ret_nodes) { |
| int index; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); |
| for (auto edge : n->in_edges()) { |
| if (edge->src_output() == Graph::kControlSlot) { |
| g->AddControlEdge(edge->src(), send_from_host_node); |
| } else { |
| g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index); |
| } |
| } |
| g->RemoveNode(n); |
| } |
| g->AddEdge(key_placeholder, 0, send_from_host_node, |
| send_from_host_dtypes->size()); |
| return send_from_host_node; |
| } |
| |
| // Returns input shapes (excluding key placeholder) for `send_from_host_node` |
| // if they are all fully defined; absl::nullopt otherwise. |
| absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes( |
| int num_inputs, Node* send_from_host_node) { |
| std::vector<PartialTensorShape> results(num_inputs); |
| for (int i = 0; i < num_inputs; i++) { |
| const Edge* e; |
| if (!send_from_host_node->input_edge(i, &e).ok()) { |
| return absl::nullopt; |
| } |
| |
| std::vector<PartialTensorShape> shapes; |
| if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes) |
| .ok()) { |
| return absl::nullopt; |
| } |
| |
| const PartialTensorShape shape = shapes[e->src_output()]; |
| if (!shape.IsFullyDefined()) { |
| return absl::nullopt; |
| } |
| |
| results[e->dst_input()] = shape; |
| } |
| return results; |
| } |
| |
| string host_compute_node_name(const string& original_oc_name) { |
| return absl::StrCat("outside_compilation_", original_oc_name, |
| "_host_compute"); |
| } |
| |
| // Builds XlaHostCompute NodeDef from the outside compilation call node. |
| xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef( |
| const Node* call_node, const std::map<string, int>& host_compute_core, |
| const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) { |
| string original_oc_name; |
| TF_RETURN_IF_ERROR(GetNodeAttr( |
| call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name)); |
| NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name), |
| "XlaHostCompute"); |
| |
| // Copy all attributes. |
| for (auto attr : call_node->attrs()) { |
| host_compute_builder.Attr(attr.first, attr.second); |
| } |
| |
| // Populate tpu_core assignment. |
| const auto iter = host_compute_core.find(original_oc_name); |
| if (iter != host_compute_core.end()) { |
| int core = iter->second; |
| host_compute_builder.Attr("tpu_core", core); |
| } |
| |
| // Set input tokens and other outside compilation clusters that current |
| // cluster depends in `kXlaTokenArgNodeName`. This is needed because when |
| // outside compilation subgraphs are encapsulated and moved to host graph, |
| // control/data edges between them will only be reflected in host graph. |
| // From XLA's perspective, two originally dependent clusters are no longer |
| // connected, which makes them look like they can be scheduled for execution |
| // in arbitrary order even though in fact they must be executed in order |
| // according to their host-side graph dependency. This can cause deadlock. |
| // Therefore, we hint XLA what the correct ordering of these clusters should |
| // be to avoid deadlocks. |
| std::vector<string> xla_token_input_nodes; |
| xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName); |
| auto cluster_deps_it = cluster_deps.find(original_oc_name); |
| if (cluster_deps_it != cluster_deps.end()) { |
| for (auto dep : cluster_deps_it->second) { |
| xla_token_input_nodes.emplace_back(host_compute_node_name(dep)); |
| } |
| } |
| host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes); |
| |
| // Populate inputs. |
| std::vector<DataType> input_dtypes; |
| TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes)); |
| std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size()); |
| for (auto e : call_node->in_edges()) { |
| if (e->IsControlEdge()) { |
| continue; |
| } |
| |
| if (e->dst_input() < 0 || e->dst_input() >= input_dtypes.size()) { |
| return errors::Internal("Invalid dst_input: ", e->dst_input()); |
| } |
| inputs[e->dst_input()] = NodeDefBuilder::NodeOut{ |
| e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]}; |
| } |
| host_compute_builder.Input(inputs); |
| |
| NodeDef new_def; |
| TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def)); |
| return new_def; |
| } |
| |
| TF_ATTRIBUTE_NOINLINE Status |
| ValidateOutsideCompilationCallNode(Node* call_node) { |
| // DT_INT64 as input/output for outside compilation is not supported yet: |
| // b/120809951. |
| for (const Edge* e : call_node->in_edges()) { |
| if (e->IsControlEdge()) { |
| continue; |
| } |
| DataType dtype = e->src()->output_type(e->src_output()); |
| if (dtype == DT_INT64) { |
| return errors::Unimplemented( |
| "int64 input for outside compilation is not supported yet: " |
| "b/120809951. Please cast output of node ", |
| e->src()->DebugString(), |
| " to int32 before feeding it into outside compilation."); |
| } |
| } |
| for (const Edge* e : call_node->out_edges()) { |
| if (e->IsControlEdge()) { |
| continue; |
| } |
| DataType dtype = e->dst()->input_type(e->dst_input()); |
| if (dtype == DT_INT64) { |
| return errors::Unimplemented( |
| "int64 output for outside compilation is not supported yet: " |
| "b/120809951. Please cast input of node ", |
| e->dst()->DebugString(), |
| " to int32 before returning it from outside compilation."); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // Replace outside compilation function call node with XlaHostCompute node. |
| TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode( |
| Graph* g, Node* call_node, const std::map<string, int>& host_compute_core, |
| const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) { |
| // Build XlaHostCompute NodeDef. |
| TF_ASSIGN_OR_RETURN( |
| NodeDef node_def, |
| BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps)); |
| TF_ASSIGN_OR_RETURN(Node * host_compute_node, |
| ReplaceNode(g, call_node, node_def)); |
| VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString(); |
| |
| return host_compute_node; |
| } |
| |
| // Resets "_device_ordinal" attr to placeholder value for related nodes |
| // (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes |
| // containing XlaRecvAtHost/XlaSendFromHost). |
| Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { |
| AttrValue device_ordinal_value; |
| device_ordinal_value.set_placeholder("_device_ordinal"); |
| for (Node* n : g->nodes()) { |
| if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { |
| continue; |
| } |
| |
| if (n->type_string() == "_XlaRecvAtHost" || |
| n->type_string() == "_XlaSendFromHost") { |
| n->ClearAttr("device_ordinal"); |
| n->AddAttr("device_ordinal", device_ordinal_value); |
| } else if (n->IsIfNode()) { |
| for (const string& attr_name : |
| std::vector<string>{"then_branch", "else_branch"}) { |
| NameAttrList branch_func; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); |
| (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; |
| n->ClearAttr(attr_name); |
| n->AddAttr(attr_name, branch_func); |
| } |
| } else if (n->IsWhileNode()) { |
| for (const string& attr_name : std::vector<string>{"cond", "body"}) { |
| NameAttrList branch_func; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); |
| (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; |
| n->ClearAttr(attr_name); |
| n->AddAttr(attr_name, branch_func); |
| } |
| } else if (HasNodeAttr(n->def(), "_device_ordinal")) { |
| // Function call node containing outside compilation. |
| n->ClearAttr("_device_ordinal"); |
| n->AddAttr("_device_ordinal", device_ordinal_value); |
| } else { |
| return errors::Internal("Unknown node marked with ", |
| kXlaHasHostTransferAttrName, ": ", |
| n->DebugString()); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // Cheap check to tell whether FunctionDef contains a lifted argument. |
| bool HasLiftedArgs(const FunctionDef& function_def) { |
| return absl::c_any_of(function_def.node_def(), [](const NodeDef& node_def) { |
| return (node_def.op() == "Placeholder" && |
| node_def.attr().find(kXlaLiftedArgOutsideCompilationAttrName) != |
| node_def.attr().end()); |
| }); |
| } |
| |
| // Find lifted arguments in a function body and their corresponding outside |
| // compilation nodes. |
| xla::StatusOr<std::vector<std::pair<Node*, Node*>>> |
| LiftedArgsAndOutsideCompilationNodesInFunctionBody( |
| const FunctionBody& function_body, |
| const std::unordered_map<string, Node*>& outside_compilation_attr_to_node) { |
| std::vector<std::pair<Node*, Node*>> |
| lifted_arg_nodes_and_outside_compilation_nodes; |
| for (Node* n : function_body.graph->op_nodes()) { |
| string oc_cluster; |
| if (n->type_string() == "Placeholder" && |
| GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName, |
| &oc_cluster) |
| .ok()) { |
| TF_RET_CHECK(outside_compilation_attr_to_node.find(oc_cluster) != |
| outside_compilation_attr_to_node.end()); |
| lifted_arg_nodes_and_outside_compilation_nodes.emplace_back( |
| n, outside_compilation_attr_to_node.at(oc_cluster)); |
| } |
| } |
| return lifted_arg_nodes_and_outside_compilation_nodes; |
| } |
| |
| // Append lifted args' types to functional control flow node's `type_attr_name` |
| // attribute. |
| xla::StatusOr<std::vector<DataType>> UpdateTypesAttribute( |
| const std::vector<std::pair<Node*, Node*>>& |
| lifted_arg_nodes_and_outside_compilation_nodes, |
| const string& type_attr_name, Node* n) { |
| std::vector<DataType> data_types; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types)); |
| for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) { |
| Node* outside_compilation_node = pair.second; |
| DataType data_type; |
| TF_RET_CHECK(outside_compilation_node->IsIdentity() || |
| outside_compilation_node->type_string() == "Placeholder"); |
| if (outside_compilation_node->IsIdentity()) { |
| TF_RETURN_IF_ERROR( |
| GetNodeAttr(outside_compilation_node->def(), "T", &data_type)); |
| } else { |
| TF_RETURN_IF_ERROR( |
| GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type)); |
| } |
| data_types.push_back(data_type); |
| } |
| n->ClearAttr(type_attr_name); |
| n->AddAttr(type_attr_name, data_types); |
| |
| return data_types; |
| } |
| |
| // Add edges from lifted outside compilation argument nodes to `n` in Graph `g`. |
| void AddEdgesFromOutsideCompilationNodes( |
| const int original_arg_count, const int arg_to_input_edge_offset, |
| const std::vector<DataType>& data_types, |
| const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) { |
| // Add edges from outside compilation nodes to While node. |
| for (int i = original_arg_count; i < data_types.size(); i++) { |
| Node* outside_compilation_node = |
| outside_compilation_nodes[i - original_arg_count]; |
| g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset); |
| } |
| } |
| |
| // Construct _Arg that maps to lifted outside compilation argument node input. |
| xla::StatusOr<Node*> AddOutsideCompilationInputArgToFunctionBody( |
| const FunctionBody& function_body, const int arg_idx, |
| const DataType& data_type) { |
| NodeDefBuilder arg_builder(absl::StrCat("arg_", arg_idx), "_Arg"); |
| arg_builder.Attr("T", data_type); |
| arg_builder.Attr("index", arg_idx); |
| NodeDef arg_def; |
| TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); |
| |
| Status s; |
| Node* arg_node = function_body.graph->AddNode(arg_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| return arg_node; |
| } |
| |
| // Add _Retval node that matches newly added `arg_node` and connect `arg_node` |
| // to it. |
| Status AddMatchingRetvalNode(const FunctionBody& function_body, |
| const int arg_idx, const DataType& data_type, |
| Node* arg_node) { |
| NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval"); |
| ret_builder.Attr("T", data_type); |
| ret_builder.Attr("index", arg_idx); |
| ret_builder.Input(arg_node->name(), 0, data_type); |
| NodeDef ret_def; |
| TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); |
| Status s; |
| Node* ret_node = function_body.graph->AddNode(ret_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| function_body.graph->AddEdge(arg_node, 0, ret_node, 0); |
| |
| return Status::OK(); |
| } |
| |
| void ReplaceLiftedArgNodePlaceholderWithArg( |
| const FunctionBody& function_body, const int original_arg_count, |
| const int arg_idx, const std::vector<Node*>& lifted_arg_nodes, |
| Node* arg_node) { |
| Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count]; |
| // This might happen because lifted_arg_node only exists in one branch of an |
| // If node, and we are handling the other branch. |
| if (!lifted_arg_node) { |
| return; |
| } |
| |
| for (const Edge* e : lifted_arg_node->out_edges()) { |
| if (e->IsControlEdge()) { |
| function_body.graph->AddControlEdge(arg_node, e->dst()); |
| } else { |
| function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input()); |
| } |
| } |
| function_body.graph->RemoveNode(lifted_arg_node); |
| } |
| |
| // Reconnect outside compilation lifted arguments in a functional While node to |
| // its outside compilation tensor sources. |
| Status PostprocessLiftedArgsForWhile( |
| const std::unordered_map<string, Node*>& outside_compilation_attr_to_node, |
| Graph* g, Node* n, FunctionLibraryDefinition* fld) { |
| TF_RET_CHECK(n->IsWhileNode()); |
| |
| // Check if there is any lifted args in body function. |
| NameAttrList body_func; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "body", &body_func)); |
| const FunctionDef* body_function_def = fld->Find(body_func.name()); |
| TF_RET_CHECK(body_function_def); |
| |
| if (!HasLiftedArgs(*body_function_def)) { |
| return Status::OK(); |
| } |
| |
| // Gather all lifted args. |
| std::unique_ptr<FunctionBody> body_function_body; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_function_def, |
| AttrSlice(&body_func.attr()), fld, |
| &body_function_body)); |
| |
| int original_arg_count = body_function_body->arg_nodes.size(); |
| |
| TF_ASSIGN_OR_RETURN( |
| auto lifted_arg_nodes_and_outside_compilation_nodes, |
| LiftedArgsAndOutsideCompilationNodesInFunctionBody( |
| *body_function_body, outside_compilation_attr_to_node)); |
| |
| // Append lifted args' types to While node's T attribute. |
| TF_ASSIGN_OR_RETURN( |
| std::vector<DataType> data_types, |
| UpdateTypesAttribute(lifted_arg_nodes_and_outside_compilation_nodes, "T", |
| n)); |
| |
| // Add edges from outside compilation nodes to While node. |
| std::vector<Node*> outside_compilation_nodes; |
| std::transform( |
| lifted_arg_nodes_and_outside_compilation_nodes.begin(), |
| lifted_arg_nodes_and_outside_compilation_nodes.end(), |
| std::back_inserter(outside_compilation_nodes), |
| [](const std::pair<Node*, Node*>& pair) { return pair.second; }); |
| AddEdgesFromOutsideCompilationNodes(original_arg_count, |
| /*arg_to_input_edge_offset=*/0, |
| data_types, outside_compilation_nodes, g, |
| n); |
| |
| // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg |
| // nodes with the new _Arg nodes. |
| std::vector<Node*> lifted_arg_nodes; |
| std::transform( |
| lifted_arg_nodes_and_outside_compilation_nodes.begin(), |
| lifted_arg_nodes_and_outside_compilation_nodes.end(), |
| std::back_inserter(lifted_arg_nodes), |
| [](const std::pair<Node*, Node*>& pair) { return pair.first; }); |
| for (int i = original_arg_count; i < data_types.size(); i++) { |
| TF_ASSIGN_OR_RETURN(Node * arg_node, |
| AddOutsideCompilationInputArgToFunctionBody( |
| *body_function_body, i, data_types[i])); |
| |
| TF_RETURN_IF_ERROR( |
| AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node)); |
| |
| ReplaceLiftedArgNodePlaceholderWithArg( |
| *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node); |
| } |
| |
| FunctionDef rewritten_body_function_def; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef( |
| *body_function_body->graph, body_func.name(), HostGraphControlRetMapping, |
| &rewritten_body_function_def)); |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(body_func.name(), rewritten_body_function_def)); |
| |
| // In cond_graph, just add new _Arg nodes. |
| NameAttrList cond_func; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "cond", &cond_func)); |
| const FunctionDef* cond_function_def = fld->Find(cond_func.name()); |
| TF_RET_CHECK(cond_function_def); |
| std::unique_ptr<FunctionBody> cond_function_body; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_function_def, |
| AttrSlice(&cond_func.attr()), fld, |
| &cond_function_body)); |
| |
| for (int i = original_arg_count; i < data_types.size(); i++) { |
| xla::StatusOr<Node*> arg_node_or = |
| AddOutsideCompilationInputArgToFunctionBody(*cond_function_body, i, |
| data_types[i]); |
| TF_RETURN_IF_ERROR(arg_node_or.status()); |
| } |
| |
| FunctionDef rewritten_cond_function_def; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef( |
| *cond_function_body->graph, cond_func.name(), HostGraphControlRetMapping, |
| &rewritten_cond_function_def)); |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(cond_func.name(), rewritten_cond_function_def)); |
| |
| return Status::OK(); |
| } |
| |
| Status PostprocessLiftedArgsForIf( |
| const std::unordered_map<string, Node*>& outside_compilation_attr_to_node, |
| Graph* g, Node* n, FunctionLibraryDefinition* fld) { |
| TF_RET_CHECK(n->IsIfNode()); |
| |
| NameAttrList then_branch_func; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func)); |
| const FunctionDef* then_branch_function_def = |
| fld->Find(then_branch_func.name()); |
| TF_RET_CHECK(then_branch_function_def); |
| |
| NameAttrList else_branch_func; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func)); |
| const FunctionDef* else_branch_function_def = |
| fld->Find(else_branch_func.name()); |
| TF_RET_CHECK(else_branch_function_def); |
| |
| // Nothing to do if neither branch contains any lifted arguments. |
| if (!HasLiftedArgs(*then_branch_function_def) && |
| !HasLiftedArgs(*else_branch_function_def)) { |
| return Status::OK(); |
| } |
| |
| std::unique_ptr<FunctionBody> then_branch_function_body; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
| *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld, |
| &then_branch_function_body)); |
| |
| std::unique_ptr<FunctionBody> else_branch_function_body; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
| *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld, |
| &else_branch_function_body)); |
| |
| // Then and else branches have same argument count and argument data types. |
| int original_arg_count = then_branch_function_body->arg_nodes.size(); |
| |
| TF_ASSIGN_OR_RETURN( |
| auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes, |
| LiftedArgsAndOutsideCompilationNodesInFunctionBody( |
| *then_branch_function_body, outside_compilation_attr_to_node)); |
| |
| TF_ASSIGN_OR_RETURN( |
| auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes, |
| LiftedArgsAndOutsideCompilationNodesInFunctionBody( |
| *else_branch_function_body, outside_compilation_attr_to_node)); |
| |
| // Merge lifted args from then and else branches. |
| std::vector<Node*> outside_compilation_nodes; |
| std::vector<Node*> then_branch_lifted_arg_nodes; |
| for (const auto& pair : |
| then_branch_lifted_arg_nodes_and_outside_compilation_nodes) { |
| outside_compilation_nodes.push_back(pair.second); |
| then_branch_lifted_arg_nodes.push_back(pair.first); |
| } |
| for (const auto& pair : |
| else_branch_lifted_arg_nodes_and_outside_compilation_nodes) { |
| if (std::find(outside_compilation_nodes.begin(), |
| outside_compilation_nodes.end(), |
| pair.second) == outside_compilation_nodes.end()) { |
| outside_compilation_nodes.push_back(pair.second); |
| // Then branch does not contain this lifted arg. Add an empty item to |
| // then_branch_lifted_arg_nodes. |
| then_branch_lifted_arg_nodes.push_back(nullptr); |
| } |
| } |
| // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes. |
| std::vector<Node*> else_branch_lifted_arg_nodes( |
| outside_compilation_nodes.size()); |
| for (const auto& pair : |
| else_branch_lifted_arg_nodes_and_outside_compilation_nodes) { |
| auto iter = std::find(outside_compilation_nodes.begin(), |
| outside_compilation_nodes.end(), pair.second); |
| TF_RET_CHECK(iter != outside_compilation_nodes.end()); |
| int index = iter - outside_compilation_nodes.begin(); |
| else_branch_lifted_arg_nodes[index] = pair.first; |
| } |
| |
| // Append lifted args' types to If node's Tin attribute. |
| std::vector<DataType> data_types; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types)); |
| for (Node* n : outside_compilation_nodes) { |
| data_types.push_back(n->output_type(0)); |
| } |
| n->ClearAttr("Tin"); |
| n->AddAttr("Tin", data_types); |
| |
| // Add edges from outside compilation nodes to If node. If node's input #0 |
| // is predicate input, input #1 maps to _Arg #0 of branch functions, thus |
| // arg_to_input_edge_offset is set to 1. |
| AddEdgesFromOutsideCompilationNodes(original_arg_count, |
| /*arg_to_input_edge_offset=*/1, |
| data_types, outside_compilation_nodes, g, |
| n); |
| |
| for (int i = original_arg_count; i < data_types.size(); ++i) { |
| TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node, |
| AddOutsideCompilationInputArgToFunctionBody( |
| *then_branch_function_body, i, data_types[i])); |
| |
| ReplaceLiftedArgNodePlaceholderWithArg( |
| *then_branch_function_body, original_arg_count, i, |
| then_branch_lifted_arg_nodes, then_branch_arg_node); |
| |
| TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node, |
| AddOutsideCompilationInputArgToFunctionBody( |
| *else_branch_function_body, i, data_types[i])); |
| |
| ReplaceLiftedArgNodePlaceholderWithArg( |
| *else_branch_function_body, original_arg_count, i, |
| else_branch_lifted_arg_nodes, else_branch_arg_node); |
| } |
| |
| FunctionDef rewritten_then_branch_function_def; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef( |
| *then_branch_function_body->graph, then_branch_func.name(), |
| HostGraphControlRetMapping, &rewritten_then_branch_function_def)); |
| TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_func.name(), |
| rewritten_then_branch_function_def)); |
| |
| FunctionDef rewritten_else_branch_function_def; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef( |
| *else_branch_function_body->graph, else_branch_func.name(), |
| HostGraphControlRetMapping, &rewritten_else_branch_function_def)); |
| TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_func.name(), |
| rewritten_else_branch_function_def)); |
| return Status::OK(); |
| } |
| |
| Status PostprocessLiftedArgsForCall( |
| const std::unordered_map<string, Node*>& outside_compilation_attr_to_node, |
| Graph* g, Node* n, FunctionLibraryDefinition* fld) { |
| const FunctionDef* fdef = fld->Find(n->type_string()); |
| TF_RET_CHECK(fdef); |
| |
| // Nothing to do if the function does not contain any lifted arguments. |
| if (!HasLiftedArgs(*fdef)) { |
| return Status::OK(); |
| } |
| |
| std::unique_ptr<FunctionBody> fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody)); |
| |
| int original_arg_count = fbody->arg_nodes.size(); |
| |
| TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes, |
| LiftedArgsAndOutsideCompilationNodesInFunctionBody( |
| *fbody, outside_compilation_attr_to_node)); |
| |
| // Append lifted args' types to call node's input data types. |
| std::vector<DataType> data_types(n->input_types().begin(), |
| n->input_types().end()); |
| for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) { |
| Node* outside_compilation_node = pair.second; |
| DataType data_type; |
| TF_RET_CHECK(outside_compilation_node->IsIdentity() || |
| outside_compilation_node->type_string() == "Placeholder"); |
| if (outside_compilation_node->IsIdentity()) { |
| TF_RETURN_IF_ERROR( |
| GetNodeAttr(outside_compilation_node->def(), "T", &data_type)); |
| } else { |
| TF_RETURN_IF_ERROR( |
| GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type)); |
| } |
| data_types.push_back(data_type); |
| } |
| |
| std::vector<Node*> lifted_arg_nodes; |
| std::transform( |
| lifted_arg_nodes_and_outside_compilation_nodes.begin(), |
| lifted_arg_nodes_and_outside_compilation_nodes.end(), |
| std::back_inserter(lifted_arg_nodes), |
| [](const std::pair<Node*, Node*>& pair) { return pair.first; }); |
| for (int i = original_arg_count; i < data_types.size(); ++i) { |
| TF_ASSIGN_OR_RETURN( |
| Node * arg_node, |
| AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i])); |
| |
| ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i, |
| lifted_arg_nodes, arg_node); |
| } |
| |
| FunctionDef rewritten_fdef; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(), |
| HostGraphControlRetMapping, |
| &rewritten_fdef)); |
| TF_RETURN_IF_ERROR(fld->ReplaceFunction(n->type_string(), rewritten_fdef)); |
| |
| // We need to recreate the node. Otherwise TF will not know n->num_inputs() |
| // has increased. |
| NodeDef node_def = n->def(); |
| for (int i = original_arg_count; i < data_types.size(); i++) { |
| Node* outside_compilation_node = |
| lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count] |
| .second; |
| node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0)); |
| } |
| TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def)); |
| |
| // Add edges from outside compilation nodes to call node. |
| std::vector<Node*> outside_compilation_nodes; |
| std::transform( |
| lifted_arg_nodes_and_outside_compilation_nodes.begin(), |
| lifted_arg_nodes_and_outside_compilation_nodes.end(), |
| std::back_inserter(outside_compilation_nodes), |
| [](const std::pair<Node*, Node*>& pair) { return pair.second; }); |
| AddEdgesFromOutsideCompilationNodes(original_arg_count, |
| /*arg_to_input_edge_offset=*/0, |
| data_types, outside_compilation_nodes, g, |
| n); |
| |
| return Status::OK(); |
| } |
| |
| // Creates a mapping from outside compilation cluster name to lifted argument |
| // placeholder. |
| xla::StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode( |
| const Graph& g) { |
| std::unordered_map<string, Node*> outside_compilation_attr_to_node; |
| for (Node* n : g.op_nodes()) { |
| bool is_lifted_arg; |
| string outside_compilation_attr; |
| if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) && |
| TryGetNodeAttr(n->def(), "_xla_outside_compilation", |
| &outside_compilation_attr)) { |
| TF_RET_CHECK(is_lifted_arg); |
| TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder"); |
| outside_compilation_attr_to_node[outside_compilation_attr] = n; |
| } |
| } |
| |
| return outside_compilation_attr_to_node; |
| } |
| |
| Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { |
| TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node, |
| OutsideCompilationAttrToNode(*g)); |
| |
| std::vector<Node*> call_nodes; |
| for (Node* n : g->op_nodes()) { |
| if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { |
| continue; |
| } |
| |
| if (n->IsWhileNode()) { |
| TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile( |
| outside_compilation_attr_to_node, g, n, fld)); |
| } |
| |
| if (n->IsIfNode()) { |
| TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf( |
| outside_compilation_attr_to_node, g, n, fld)); |
| } |
| |
| // Outside compilation host side function call will always be direct |
| // function call nodes. |
| // Function call nodes need to be handled separately because we rewrite |
| // nodes in `PostprocessLiftedArgsForCall`. |
| if (fld->Contains(n->type_string())) { |
| call_nodes.push_back(n); |
| } |
| } |
| |
| for (Node* n : call_nodes) { |
| TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall( |
| outside_compilation_attr_to_node, g, n, fld)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // For an XLA computation, builds host side graph given all outside compilation |
| // graphs inside it. The host side graph contains: |
| // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and |
| // XlaSendFromHost to this sequencer node, so all outside compilation nodes |
| // will be executed *before* this sequencer). |
| // 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will |
| // replace this node with compilation result node. |
| // 3) all outside compilation graphs. |
| Status ConstructHostGraph( |
| const string& xla_cluster_name, const string& outside_compilation_attr_name, |
| const std::vector<string>& outside_compilation_host_graphs, |
| FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) { |
| host_graph->reset(new Graph(fld)); |
| |
| // Create sequencer node in host graph. |
| NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), |
| "NoOp"); |
| sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name); |
| NodeDef sequencer_def; |
| TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); |
| Status s; |
| Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| |
| // Create key placeholder in host graph. |
| TF_ASSIGN_OR_RETURN( |
| Node * key_placeholder, |
| AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get())); |
| |
| // For each outside compilation graph, copy them to host graph with the |
| // following changes: |
| // a) Use key_placeholder in host graph instead of its own. |
| // b) Add control edge from host transfer nodes (XlaRecvAtHost, |
| // XlaSendFromHost, If/While nodes containing |
| // XlaRecvAtHost/XlaSendFromHost) to sequencer node. |
| // c) Clear node_def.device(), so device placer won't get confused. |
| for (const string& host_func : outside_compilation_host_graphs) { |
| VLOG(4) << "Expanding host graph " << host_func; |
| // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder |
| // value after we expanded all host graphs. We cannot just use placeholder |
| // value here because FunctionDef instantiation does not allow placeholder |
| // value for attributes. |
| AttrValue device_ordinal_attr; |
| device_ordinal_attr.set_i(0); |
| protobuf::Map<string, AttrValue> attrs; |
| attrs["_device_ordinal"] = device_ordinal_attr; |
| std::unique_ptr<FunctionBody> host_fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
| *fld->Find(host_func), AttrSlice(&attrs), fld, &host_fbody)); |
| |
| // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse |
| // reachable from sink node so all nodes will be copied. |
| // TODO(b/77601805): consolidate copy graph functions. |
| FixupSourceAndSinkEdges(host_fbody->graph); |
| |
| std::map<const Node*, Node*> node_map; |
| node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); |
| node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); |
| Status s; |
| ReverseDFS( |
| *host_fbody->graph, /*enter=*/nullptr, |
| [&](const Node* n) { |
| if (!s.ok()) { |
| return; |
| } |
| |
| Node* copy; |
| if (node_map.find(n) != node_map.end()) { |
| // Already copied this node. |
| copy = node_map.at(n); |
| } else if (IsKeyPlaceholderNode(*n)) { |
| // Change a). |
| copy = key_placeholder; |
| node_map[n] = copy; |
| } else { |
| // Copy the node. |
| NodeDef copy_def = n->def(); |
| // Change c). |
| copy_def.clear_device(); |
| copy = (*host_graph)->AddNode(copy_def, &s); |
| if (!s.ok()) { |
| return; |
| } |
| node_map[n] = copy; |
| } |
| |
| // Only handle input edges. Output edges will be added later as |
| // its output nodes' input edges. |
| for (auto e : n->in_edges()) { |
| if (node_map.find(e->src()) == node_map.end()) { |
| s = errors::Internal("Cannot find node image for ", |
| e->src()->DebugString()); |
| return; |
| } |
| (*host_graph) |
| ->AddEdge(node_map[e->src()], e->src_output(), copy, |
| e->dst_input()); |
| } |
| |
| // Change b). |
| if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) { |
| (*host_graph)->AddControlEdge(copy, sequencer); |
| } |
| }, |
| NodeComparatorID()); |
| |
| if (!s.ok()) { |
| return s; |
| } |
| } |
| // Reset "_device_ordinal" to placeholder value. |
| TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(host_graph->get())); |
| |
| // sequencer and key_placeholder might be dead nodes. Prune them if necessary. |
| // - sequencer should be pruned iff it has no input control edges from |
| // RecvAtHost/SendFromHost. If it has input control edge, we connect it to |
| // sink node so it won't be pruned. |
| // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. |
| // We don't need to do anything special. |
| if (!sequencer->in_edges().empty()) { |
| (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node()); |
| } |
| PruneForReverseReachability( |
| host_graph->get(), |
| std::unordered_set<const Node*>{(*host_graph)->sink_node()}); |
| |
| // Postprocess edges between different outside compilations. |
| TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( |
| host_graph->get(), outside_compilation_attr_name)); |
| |
| // Postprocess lifted arg nodes. |
| TF_RETURN_IF_ERROR(PostprocessLiftedArgs(host_graph->get(), fld)); |
| |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_", |
| xla_cluster_name), |
| **host_graph, fld); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Expand XLA computation's outside compilation host side graph into main graph. |
| // Add a control edge between sequencer node and the XLA computation node. |
| Status ExpandHostGraphIntoMainGraph(Graph* main_graph, |
| FunctionLibraryDefinition* fld, |
| const string& host_graph_func_name, |
| Node* xla_computation_node, |
| Node* pivot_node) { |
| // Temporarily use "0" as "_device_ordinal". It will be rewritten with the |
| // correct value in a later pass. We cannot just use placeholder value here |
| // because FunctionDef instantiation does not allow placeholder value for |
| // attributes. |
| AttrValue device_ordinal_attr; |
| device_ordinal_attr.set_i(0); |
| protobuf::Map<string, AttrValue> attrs; |
| attrs["_device_ordinal"] = device_ordinal_attr; |
| std::unique_ptr<FunctionBody> fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(host_graph_func_name), |
| AttrSlice(&attrs), fld, &fbody)); |
| Graph* host_graph = fbody->graph; |
| |
| // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse |
| // reachable from sink node so all nodes will be copied. |
| // TODO(b/77601805): consolidate copy graph functions. |
| FixupSourceAndSinkEdges(host_graph); |
| |
| // Copy all nodes. |
| std::map<const Node*, Node*> node_map; |
| if (pivot_node) { |
| node_map[host_graph->source_node()] = pivot_node; |
| } else { |
| node_map[host_graph->source_node()] = main_graph->source_node(); |
| } |
| node_map[host_graph->sink_node()] = main_graph->sink_node(); |
| Status s = Status::OK(); |
| auto copy_node_fn = [&](const Node* n) { |
| if (!s.ok()) { |
| return; |
| } |
| |
| Node* copy; |
| if (node_map.find(n) != node_map.end()) { |
| // Already copied this node. |
| copy = node_map.at(n); |
| } else { |
| // Copy the node. |
| NodeDef copy_def = n->def(); |
| copy = main_graph->AddNode(copy_def, &s); |
| if (!s.ok()) { |
| return; |
| } |
| node_map[n] = copy; |
| } |
| |
| // Only handle input edges. Output edges will be added later as its output |
| // nodes' input edges. |
| for (auto e : n->in_edges()) { |
| if (node_map.find(e->src()) == node_map.end()) { |
| s = errors::Internal("Cannot find node image for ", |
| e->src()->DebugString()); |
| return; |
| } |
| main_graph->AddEdge(node_map[e->src()], e->src_output(), copy, |
| e->dst_input()); |
| } |
| |
| // Add control edge from sequencer to XLA computation node. |
| if (copy->type_string() == "NoOp" && |
| HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) { |
| main_graph->AddControlEdge(copy, xla_computation_node); |
| } |
| }; |
| ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID()); |
| return s; |
| } |
| |
| // Rewrites shape inference graph for outside compilation: |
| // 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from |
| // `host_graph`. Because we might still have outside compilation to outside |
| // compilation placeholder nodes in shape inference graph, which will prevent |
| // us from inferring XlaSendFromHost shape. But in `host_graph`, we already |
| // removed those placeholder nodes. |
| // 2) Remove control edges. |
| // 3) Prune nodes that are not useful for shape inference. |
| Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, |
| Graph* host_graph, Node* pivot_node, |
| FunctionLibraryDefinition* fld) { |
| // Use "0" as "_device_ordinal". It does not matter for shape inference. |
| AttrValue device_ordinal_attr; |
| device_ordinal_attr.set_i(0); |
| protobuf::Map<string, AttrValue> attrs; |
| attrs["_device_ordinal"] = device_ordinal_attr; |
| std::unique_ptr<FunctionBody> fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
| *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, &fbody)); |
| Graph* g = fbody->graph; |
| |
| // Find SendFromHost node. |
| Node* send_from_host = nullptr; |
| for (Node* n : g->nodes()) { |
| if (n->type_string() == "_XlaSendFromHost") { |
| send_from_host = n; |
| break; |
| } |
| } |
| if (!send_from_host) { |
| return errors::Internal("Shape inference graph ", |
| shape_inference_graph_name, |
| " does not have _XlaSendFromHost node."); |
| } |
| |
| // See if the SendFromHost node exists in `host_graph`. |
| Node* send_node_in_host_graph = nullptr; |
| for (Node* n : host_graph->nodes()) { |
| if (n->name() == send_from_host->name()) { |
| send_node_in_host_graph = n; |
| break; |
| } |
| } |
| if (send_node_in_host_graph) { |
| // This is an "top-level" outside compilation. Clear the graph, and copy |
| // SendFromHost and all its predecessors from `host_graph`. |
| std::vector<Node*> nodes; |
| for (Node* n : g->op_nodes()) { |
| nodes.push_back(n); |
| } |
| for (Node* n : nodes) { |
| g->RemoveNode(n); |
| } |
| Node* start_node = pivot_node ? pivot_node : host_graph->source_node(); |
| // Reverse DFS from send_from_host_main_graph, and stop at start_node. |
| struct Visit { |
| Node* n; |
| bool is_exiting; |
| }; |
| std::vector<Visit> stack{{send_node_in_host_graph, false}}; |
| std::map<Node*, Node*> node_map; |
| node_map[host_graph->source_node()] = g->source_node(); |
| while (!stack.empty()) { |
| Visit& curr = stack.back(); |
| if (curr.is_exiting) { |
| if (node_map.find(curr.n) == node_map.end()) { |
| Node* copy = g->CopyNode(curr.n); |
| if (curr.n != start_node) { |
| for (const Edge* e : curr.n->in_edges()) { |
| auto node_iter = node_map.find(e->src()); |
| if (node_iter == node_map.end()) { |
| return errors::Internal("Cannot find node image for ", |
| e->src()->DebugString()); |
| } |
| g->AddEdge(node_iter->second, e->src_output(), copy, |
| e->dst_input()); |
| } |
| } |
| node_map[curr.n] = copy; |
| } |
| stack.pop_back(); |
| } else { |
| curr.is_exiting = true; |
| if (curr.n != start_node) { |
| for (const Edge* e : curr.n->in_edges()) { |
| if (node_map.find(e->src()) != node_map.end()) { |
| continue; |
| } |
| stack.push_back({e->src(), false}); |
| } |
| } |
| } |
| } |
| |
| send_from_host = node_map[send_node_in_host_graph]; |
| } else { |
| // This is an outside compilation generated for If/While/gradient/etc. |
| // It will be enough for shape inference. Leave `g` unchanged. |
| } |
| |
| // Control edges are not useful for shape inference. Remove them. |
| for (auto e : g->edges()) { |
| if (e->IsControlEdge()) { |
| g->RemoveEdge(e); |
| } |
| } |
| |
| // Nodes that are not reverse reachable from SendFromHost are not useful for |
| // shape inference. Prune them. |
| PruneForReverseReachability(g, |
| std::unordered_set<const Node*>{send_from_host}); |
| |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile(shape_inference_graph_name, *g, fld); |
| } |
| |
| // Replace original shape inference graph. |
| FunctionDef fdef_replace; |
| TF_RETURN_IF_ERROR( |
| GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace)); |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(shape_inference_graph_name, fdef_replace)); |
| |
| return Status::OK(); |
| } |
| |
| // Builds XlaSendToHost node which sends cond predicate to host. |
| TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> BuildSendIfPredNode( |
| const string& name, const string& host_transfer_key, Node* pred_node, |
| Graph* g) { |
| NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); |
| send_pred_builder.Attr("Tinput", DT_BOOL); |
| send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); |
| send_pred_builder.Attr(kXlaTokenInputNodesAttrName, |
| std::vector<string>{kXlaTokenArgNodeName}); |
| send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); |
| NodeDef send_pred_def; |
| TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def)); |
| Status s; |
| Node* send_pred_node = g->AddNode(send_pred_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| g->AddEdge(pred_node, 0, send_pred_node, 0); |
| return send_pred_node; |
| } |
| |
| // Replaces key placeholder node with an _Arg node. |
| Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, |
| const string& func_name, |
| FunctionLibraryDefinition* fld) { |
| // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder |
| // value after rewriting. |
| AttrValue device_ordinal_attr; |
| device_ordinal_attr.set_i(0); |
| protobuf::Map<string, AttrValue> attrs; |
| attrs["_device_ordinal"] = device_ordinal_attr; |
| std::unique_ptr<FunctionBody> fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(func_name), |
| AttrSlice(&attrs), fld, &fbody)); |
| Graph* g = fbody->graph; |
| |
| // Find or create the key placeholder node. |
| Node* key_placeholder = nullptr; |
| for (Node* n : g->nodes()) { |
| if (IsKeyPlaceholderNode(*n)) { |
| key_placeholder = n; |
| break; |
| } |
| } |
| if (!key_placeholder) { |
| TF_ASSIGN_OR_RETURN(key_placeholder, |
| AddHostComputeKeyPlaceholder(xla_cluster_name, g)); |
| } |
| |
| // Build the _Arg node, and replace key placeholder node with it. |
| NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp); |
| arg_builder.Attr("T", DT_STRING); |
| arg_builder.Attr("index", 0); |
| NodeDef arg_def; |
| TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); |
| TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status()); |
| |
| // Reset "_device_ordinal" to placeholder value. |
| TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g)); |
| |
| FunctionDef replace_fdef; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef( |
| *g, func_name, HostGraphControlRetMapping, &replace_fdef)); |
| TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); |
| return Status::OK(); |
| } |
| |
| // Builds host side graph for If node. |
| TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode( |
| const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, const string& xla_cluster_name, |
| const string& if_node_name, const string& host_transfer_key, |
| const string& host_graph_func_name, FunctionLibraryDefinition* fld, |
| const string& then_branch_host_func_name, |
| const string& else_branch_host_func_name) { |
| Graph host_graph(fld); |
| string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); |
| AttrValue device_ordinal_value; |
| device_ordinal_value.set_placeholder("_device_ordinal"); |
| |
| // Step 1: add key placeholder node. |
| TF_ASSIGN_OR_RETURN( |
| Node * key_placeholder, |
| AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); |
| |
| // Step 2: build XlaRecvAtHost node to recv predicate. |
| NodeDefBuilder recv_pred_builder( |
| absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost"); |
| recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL}); |
| recv_pred_builder.Attr("key", host_transfer_key); |
| recv_pred_builder.Attr("device_ordinal", device_ordinal_value); |
| recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); |
| recv_pred_builder.Attr(outside_compilation_attr_name, |
| outside_compilation_name); |
| recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true); |
| recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING); |
| NodeDef recv_pred_def; |
| TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); |
| Status s; |
| Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0); |
| |
| // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key |
| // placeholder with an _Arg node. |
| TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( |
| xla_cluster_name, then_branch_host_func_name, fld)); |
| TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( |
| xla_cluster_name, else_branch_host_func_name, fld)); |
| |
| // Step 4: build If node to choose between `{then, else}_branch_host_graph`. |
| NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If"); |
| if_builder.Attr("Tcond", DT_BOOL); |
| if_builder.Attr("Tin", std::vector<DataType>{DT_STRING}); |
| if_builder.Attr("Tout", std::vector<DataType>{}); |
| NameAttrList host_then_branch, host_else_branch; |
| host_then_branch.set_name(then_branch_host_func_name); |
| (*host_then_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value; |
| host_else_branch.set_name(else_branch_host_func_name); |
| (*host_else_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value; |
| if_builder.Attr("then_branch", host_then_branch); |
| if_builder.Attr("else_branch", host_else_branch); |
| if_builder.Attr(kXlaHasHostTransferAttrName, true); |
| if_builder.Attr(xla_cluster_attr_name, xla_cluster_name); |
| if_builder.Attr(outside_compilation_attr_name, outside_compilation_name); |
| if_builder.Input(recv_pred_node->name(), 0, DT_BOOL); |
| std::vector<NodeDefBuilder::NodeOut> if_inputs{ |
| {key_placeholder->name(), 0, DT_STRING}}; |
| if_builder.Input(if_inputs); |
| NodeDef if_def; |
| TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def)); |
| Node* if_node = host_graph.AddNode(if_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| host_graph.AddEdge(recv_pred_node, 0, if_node, 0); |
| host_graph.AddEdge(key_placeholder, 0, if_node, 1); |
| |
| // Convert `host_graph` to function. |
| FunctionDef oc_host_graph_fdef; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, |
| &oc_host_graph_fdef)); |
| if (fld->Find(host_graph_func_name)) { |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); |
| } else { |
| TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Rewrites loop cond to add a node which sends loop cond to host. |
| TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( |
| FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func, |
| const string& while_node_name, const string& host_transfer_key) { |
| // Instantiate the loop cond function. |
| std::unique_ptr<FunctionBody> fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()), |
| AttrSlice(&loop_cond_func.attr()), |
| fld, &fbody)); |
| Graph* g = fbody->graph; |
| |
| // Find the _Retval node and the loop cond node. |
| Node* ret_node = nullptr; |
| for (Node* n : g->nodes()) { |
| if (n->type_string() == "_Retval") { |
| if (ret_node) { |
| return errors::Internal("Multiple return node for loop cond function ", |
| loop_cond_func.name(), ": ", |
| ret_node->DebugString(), " and ", |
| n->DebugString()); |
| } else { |
| ret_node = n; |
| } |
| } |
| } |
| if (!ret_node) { |
| return errors::Internal("No _Retval node for loop cond function ", |
| loop_cond_func.name()); |
| } |
| Node* loop_cond; |
| TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); |
| |
| // Build the XlaSendToHost node. |
| NodeDefBuilder send_loop_cond_builder( |
| absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); |
| send_loop_cond_builder.Attr("Tinput", DT_BOOL); |
| send_loop_cond_builder.Attr("key", |
| absl::StrCat(host_transfer_key, "_dtoh_0")); |
| send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, |
| std::vector<string>{kXlaTokenArgNodeName}); |
| send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL); |
| NodeDef send_loop_cond_def; |
| TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def)); |
| Status s; |
| Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); |
| |
| // Replace original function. |
| FunctionDef replace_fdef; |
| TF_RETURN_IF_ERROR( |
| GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); |
| TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); |
| |
| return Status::OK(); |
| } |
| |
| // Rewrites while loop cond function for host. |
| Status RewriteHostWhileLoopCond( |
| const string& cond_host_func_name, const string& while_node_name, |
| const string& host_transfer_key, const string& xla_cluster_attr_name, |
| const string& xla_cluster_name, const string& outside_compilation_attr_name, |
| const string& outside_compilation_name, FunctionLibraryDefinition* fld) { |
| // Replace key placeholder node with _Arg node. |
| TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( |
| xla_cluster_name, cond_host_func_name, fld)); |
| |
| // Instantiate cond function. |
| AttrValue device_ordinal_temp_value; |
| device_ordinal_temp_value.set_i(0); |
| protobuf::Map<string, AttrValue> attrs; |
| attrs["_device_ordinal"] = device_ordinal_temp_value; |
| std::unique_ptr<FunctionBody> cond_fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
| *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, &cond_fbody)); |
| Graph* cond_graph = cond_fbody->graph; |
| Node* key_arg = nullptr; |
| for (Node* n : cond_graph->nodes()) { |
| if (n->type_string() == "_Arg") { |
| key_arg = n; |
| } |
| } |
| if (!key_arg) { |
| return errors::Internal( |
| "No _Arg node found for host compute key in function ", |
| cond_host_func_name); |
| } |
| |
| // Add an XlaRecvAtHost node to use as cond function return value. |
| NodeDefBuilder recv_pred_builder( |
| absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost"); |
| recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL}); |
| recv_pred_builder.Attr("key", host_transfer_key); |
| AttrValue device_ordinal_value; |
| device_ordinal_value.set_placeholder("_device_ordinal"); |
| recv_pred_builder.Attr("device_ordinal", device_ordinal_value); |
| recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); |
| recv_pred_builder.Attr(outside_compilation_attr_name, |
| outside_compilation_name); |
| recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true); |
| recv_pred_builder.Input(key_arg->name(), 0, DT_STRING); |
| NodeDef recv_pred_def; |
| TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); |
| Status s; |
| Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0); |
| NodeDefBuilder ret_builder( |
| absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval"); |
| ret_builder.Attr("T", DT_BOOL); |
| ret_builder.Attr("index", 0); |
| ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL); |
| NodeDef ret_def; |
| TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); |
| Node* ret_node = cond_graph->AddNode(ret_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0); |
| |
| // Reset device_ordinal to placeholder value. |
| TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph)); |
| |
| // Replace original function. |
| FunctionDef cond_replace_fdef; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_graph, cond_host_func_name, |
| HostGraphControlRetMapping, |
| &cond_replace_fdef)); |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef)); |
| |
| return Status::OK(); |
| } |
| |
| // Rewrites while loop body function for host. |
| Status RewriteHostWhileLoopBody( |
| const string& body_host_func_name, const string& while_node_name, |
| const string& host_transfer_key, const string& xla_cluster_attr_name, |
| const string& xla_cluster_name, const string& outside_compilation_attr_name, |
| const string& outside_compilation_name, FunctionLibraryDefinition* fld) { |
| // Replace key placeholder node with _Arg node. |
| TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( |
| xla_cluster_name, body_host_func_name, fld)); |
| |
| // Instantiate body function. |
| AttrValue device_ordinal_temp_value; |
| device_ordinal_temp_value.set_i(0); |
| protobuf::Map<string, AttrValue> attrs; |
| attrs["_device_ordinal"] = device_ordinal_temp_value; |
| std::unique_ptr<FunctionBody> body_fbody; |
| TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
| *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, &body_fbody)); |
| Graph* body_graph = body_fbody->graph; |
| Node* key_arg = nullptr; |
| for (Node* n : body_graph->nodes()) { |
| if (n->type_string() == "_Arg") { |
| key_arg = n; |
| } |
| } |
| if (!key_arg) { |
| return errors::Internal( |
| "No _Arg node found for host compute key in function ", |
| body_host_func_name); |
| } |
| |
| // Add a _Retval node to loop body. |
| NodeDefBuilder ret_builder( |
| absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval"); |
| ret_builder.Attr("T", DT_STRING); |
| ret_builder.Attr("index", 0); |
| ret_builder.Input(key_arg->name(), 0, DT_STRING); |
| NodeDef ret_def; |
| TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); |
| Status s; |
| Node* ret_node = body_graph->AddNode(ret_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| body_graph->AddEdge(key_arg, 0, ret_node, 0); |
| |
| // Reset device_ordinal to placeholder value. |
| TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph)); |
| |
| // Replace original function. |
| FunctionDef body_replace_fdef; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_graph, body_host_func_name, |
| HostGraphControlRetMapping, |
| &body_replace_fdef)); |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(body_host_func_name, body_replace_fdef)); |
| |
| return Status::OK(); |
| } |
| |
| // Builds host side graph for while node. |
| TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode( |
| const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, const string& xla_cluster_name, |
| const string& while_node_name, const string& host_transfer_key, |
| const string& host_graph_func_name, FunctionLibraryDefinition* fld, |
| const string& cond_host_func_name, const string& body_host_func_name) { |
| Graph host_graph(fld); |
| string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); |
| |
| // Step 1: add key placeholder node. |
| TF_ASSIGN_OR_RETURN( |
| Node * key_placeholder, |
| AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); |
| |
| // Step 2: rewrite cond function. |
| TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond( |
| cond_host_func_name, while_node_name, host_transfer_key, |
| xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, |
| outside_compilation_name, fld)); |
| |
| // Step 3: rewrite body function. |
| TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody( |
| body_host_func_name, while_node_name, host_transfer_key, |
| xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, |
| outside_compilation_name, fld)); |
| |
| // Step 4: build While node. |
| NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name), |
| "While"); |
| while_builder.Attr("T", std::vector<DataType>{DT_STRING}); |
| NameAttrList func; |
| AttrValue device_ordinal_value; |
| device_ordinal_value.set_placeholder("_device_ordinal"); |
| (*func.mutable_attr())["_device_ordinal"] = device_ordinal_value; |
| func.set_name(cond_host_func_name); |
| while_builder.Attr("cond", func); |
| func.set_name(body_host_func_name); |
| while_builder.Attr("body", func); |
| while_builder.Attr(kXlaHasHostTransferAttrName, true); |
| while_builder.Attr(xla_cluster_attr_name, xla_cluster_name); |
| while_builder.Attr(outside_compilation_attr_name, outside_compilation_name); |
| // Make sure loop body of i-th iteration happens before loop cond of (i+1)-th |
| // iteration. |
| while_builder.Attr("parallel_iterations", 1); |
| std::vector<NodeDefBuilder::NodeOut> while_inputs{ |
| {key_placeholder->name(), 0, DT_STRING}}; |
| while_builder.Input(while_inputs); |
| NodeDef while_def; |
| TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def)); |
| Status s; |
| Node* while_node = host_graph.AddNode(while_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| host_graph.AddEdge(key_placeholder, 0, while_node, 0); |
| |
| // Convert `host_graph` to function. |
| FunctionDef oc_host_graph_fdef; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, |
| &oc_host_graph_fdef)); |
| if (fld->Find(host_graph_func_name)) { |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); |
| } else { |
| TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Builds host graph for func call nodes. |
| Status BuildHostGraphForFuncCallNode( |
| const string& xla_cluster_attr_name, const string& xla_cluster_name, |
| const string& outside_compilation_attr_name, |
| const string& func_call_node_name, const string& func_call_host_func_name, |
| const string& host_graph_func_name, FunctionLibraryDefinition* fld) { |
| Graph host_graph(fld); |
| AttrValue device_ordinal_value; |
| device_ordinal_value.set_placeholder("_device_ordinal"); |
| |
| // Step 1: add key placeholder node. |
| TF_ASSIGN_OR_RETURN( |
| Node * key_placeholder, |
| AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); |
| |
| // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg |
| // node. |
| TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( |
| xla_cluster_name, func_call_host_func_name, fld)); |
| |
| // Step 3: build a function call node with `host_func_name`, with |
| // `key_placeholder` as input. |
| NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name), |
| func_call_host_func_name, fld); |
| call_builder.Input(key_placeholder->name(), 0, DT_STRING); |
| call_builder.Attr("_device_ordinal", device_ordinal_value); |
| call_builder.Attr(kXlaHasHostTransferAttrName, true); |
| call_builder.Attr(xla_cluster_attr_name, xla_cluster_name); |
| call_builder.Attr(outside_compilation_attr_name, call_builder.node_name()); |
| NodeDef call_def; |
| TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def)); |
| Status s; |
| Node* call_node = host_graph.AddNode(call_def, &s); |
| TF_RETURN_IF_ERROR(s); |
| host_graph.AddEdge(key_placeholder, 0, call_node, 0); |
| |
| // Convert `host_graph` to function. |
| FunctionDef oc_host_graph_fdef; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, |
| HostGraphControlRetMapping, |
| &oc_host_graph_fdef)); |
| if (fld->Find(host_graph_func_name)) { |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); |
| } else { |
| TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( |
| const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, const string& xla_cluster_name, |
| const std::map<string, int>& host_compute_core, Graph* g, Node* n, |
| FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, |
| std::vector<string>* host_graphs, |
| std::vector<string>* shape_inference_graphs, |
| bool* has_outside_compilation) { |
| bool func_has_outside_compilation = false; |
| NameAttrList func; |
| if (fld->Contains(n->type_string())) { |
| func.set_name(n->type_string()); |
| typedef protobuf::Map<string, AttrValue> AttrMap; |
| *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); |
| } else if (n->IsPartitionedCall()) { |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func)); |
| } else { |
| TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp); |
| func.set_name(FunctionLibraryDefinition::kGradientOp); |
| *func.mutable_attr() = n->def().attr(); |
| } |
| string new_func_name = absl::StrCat(n->name(), "_oc"); |
| string host_func_name = absl::StrCat("oc_func_call_host_", n->name()); |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| func, new_func_name, host_func_name, host_compute_core, flr, fld, |
| shape_inference_graphs, &func_has_outside_compilation)); |
| |
| // If the function call does not have outside compilation, nothing to do. |
| if (!func_has_outside_compilation) { |
| return Status::OK(); |
| } |
| |
| *has_outside_compilation = true; |
| |
| // Change `n` to call the new function directly. |
| auto replace_builder = |
| absl::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld); |
| std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs()); |
| for (const Edge* e : n->in_edges()) { |
| if (e->IsControlEdge()) { |
| continue; |
| } |
| |
| TF_RET_CHECK(e->dst_input() >= 0 && e->dst_input() < inputs.size()); |
| inputs[e->dst_input()] = |
| NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(), |
| e->src()->output_type(e->src_output())}; |
| } |
| for (const auto& input : inputs) { |
| replace_builder->Input(input); |
| } |
| for (const auto& attr : n->attrs()) { |
| replace_builder->Attr(attr.first, attr.second); |
| } |
| auto replace_def = absl::make_unique<NodeDef>(); |
| TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get())); |
| TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def)); |
| replace->AddAttr(kXlaTokenInputNodesAttrName, |
| std::vector<string>{kXlaTokenArgNodeName}); |
| |
| // Build host side graph for the function call. |
| string oc_host_graph_name = |
| absl::StrCat("oc_func_host_graph_", replace->name()); |
| TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode( |
| xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, |
| replace->name(), host_func_name, oc_host_graph_name, fld)); |
| |
| // Record the host graph. |
| host_graphs->push_back(oc_host_graph_name); |
| |
| return Status::OK(); |
| } |
| |
| Status ExtractOutsideCompilationForIfNode( |
| const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, const string& xla_cluster_name, |
| const std::map<string, int>& host_compute_core, Graph* g, Node* n, |
| FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, |
| std::vector<string>* host_graphs, |
| std::vector<string>* shape_inference_graphs, |
| bool* has_outside_compilation) { |
| // Instantiate "then_branch" and "else_branch". |
| NameAttrList then_branch, else_branch; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); |
| |
| // Extract outside compilation for then_branch and else_branch. |
| bool then_branch_has_outside_compilation = false; |
| bool else_branch_has_outside_compilation = false; |
| string then_branch_host_func_name = |
| absl::StrCat("oc_then_branch_host_if_", n->name()), |
| else_branch_host_func_name = |
| absl::StrCat("oc_else_branch_host_if_", n->name()); |
| string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), |
| else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| then_branch, then_branch_xla_func_name, then_branch_host_func_name, |
| host_compute_core, flr, fld, shape_inference_graphs, |
| &then_branch_has_outside_compilation)); |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| else_branch, else_branch_xla_func_name, else_branch_host_func_name, |
| host_compute_core, flr, fld, shape_inference_graphs, |
| &else_branch_has_outside_compilation)); |
| |
| // If then/else branch do not have outside compilation, nothing to do. |
| if (!then_branch_has_outside_compilation && |
| !else_branch_has_outside_compilation) { |
| return Status::OK(); |
| } |
| |
| *has_outside_compilation = true; |
| |
| // Change If node to call the new functions. |
| then_branch.set_name(then_branch_xla_func_name); |
| n->ClearAttr("then_branch"); |
| n->AddAttr("then_branch", then_branch); |
| else_branch.set_name(else_branch_xla_func_name); |
| n->ClearAttr("else_branch"); |
| n->AddAttr("else_branch", else_branch); |
| |
| string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); |
| |
| // XLA computation: add a SendToHost node to send cond predicate. |
| Node* pred_node; |
| TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); |
| TF_ASSIGN_OR_RETURN( |
| Node * send_pred_node, |
| BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), |
| host_transfer_key, pred_node, g)); |
| n->AddAttr(kXlaTokenInputNodesAttrName, |
| std::vector<string>{send_pred_node->name()}); |
| |
| // Add a control edge from `send_pred_node` to If node, so XlaCompiler will |
| // visit If node after `send_pred_node`, thus the token output for |
| // `send_pred_node` has been generated. |
| g->AddControlEdge(send_pred_node, n); |
| |
| // Build host side graph for the "If" node. |
| string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); |
| TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| n->name(), host_transfer_key, oc_host_graph_name, fld, |
| then_branch_host_func_name, else_branch_host_func_name)); |
| host_graphs->push_back(oc_host_graph_name); |
| |
| return Status::OK(); |
| } |
| |
| Status ExtractOutsideCompilationForWhileNode( |
| const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, const string& xla_cluster_name, |
| const std::map<string, int>& host_compute_core, Graph* g, Node* n, |
| FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, |
| std::vector<string>* host_graphs, |
| std::vector<string>* shape_inference_graphs, |
| bool* has_outside_compilation) { |
| // Instantiate "cond" and "body". |
| NameAttrList cond, body; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); |
| |
| // Extract outside compilation for cond and body. |
| bool cond_has_outside_compilation = false; |
| bool body_has_outside_compilation = false; |
| string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), |
| body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); |
| string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), |
| body_xla_func_name = absl::StrCat(body.name(), "_oc"); |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, |
| fld, shape_inference_graphs, &cond_has_outside_compilation)); |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| body, body_xla_func_name, body_host_func_name, host_compute_core, flr, |
| fld, shape_inference_graphs, &body_has_outside_compilation)); |
| |
| // If cond/body do not have outside compilation, nothing to do. |
| if (!cond_has_outside_compilation && !body_has_outside_compilation) { |
| return Status::OK(); |
| } |
| |
| *has_outside_compilation = true; |
| |
| // Change While node to call the new functions. |
| cond.set_name(cond_xla_func_name); |
| n->ClearAttr("cond"); |
| n->AddAttr("cond", cond); |
| body.set_name(body_xla_func_name); |
| n->ClearAttr("body"); |
| n->AddAttr("body", body); |
| |
| string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); |
| |
| // XLA computation: rewrite cond function to add a SendToHost node to send |
| // loop predicate. |
| TF_RETURN_IF_ERROR( |
| AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); |
| n->AddAttr(kXlaTokenInputNodesAttrName, |
| std::vector<string>{kXlaTokenArgNodeName}); |
| |
| // Build host side graph for the "While" node. |
| string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); |
| TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| n->name(), host_transfer_key, oc_host_graph_name, fld, |
| cond_host_func_name, body_host_func_name)); |
| host_graphs->push_back(oc_host_graph_name); |
| |
| return Status::OK(); |
| } |
| |
| Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( |
| Graph* g, const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, const string& xla_cluster_name, |
| const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr, |
| FunctionLibraryDefinition* fld, std::vector<string>* host_graphs, |
| std::vector<string>* shape_inference_graphs, |
| bool* has_outside_compilation) { |
| std::vector<Node*> if_nodes, while_nodes, func_call_nodes; |
| for (Node* n : g->nodes()) { |
| if (n->IsIfNode()) { |
| if_nodes.push_back(n); |
| } else if (n->IsWhileNode()) { |
| while_nodes.push_back(n); |
| } else if (IsFunctionCall(*fld, *n)) { |
| func_call_nodes.push_back(n); |
| } |
| } |
| |
| for (Node* n : func_call_nodes) { |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs, |
| has_outside_compilation)); |
| } |
| |
| for (Node* n : if_nodes) { |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs, |
| has_outside_compilation)); |
| } |
| |
| for (Node* n : while_nodes) { |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs, |
| has_outside_compilation)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| Status RewriteOutsideCompilationSubgraphFn::operator()( |
| const std::vector<OutputTensor>& arg_source_tensors, |
| std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation, |
| std::vector<int>* output_permutation, NodeDef* node_def) { |
| string old_name = node_def->op(); |
| string new_name = |
| absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name); |
| node_def->set_op(new_name); |
| node_def->set_name(new_name); |
| |
| // Later we will run PruneForReverseReachability(), so make sure all original |
| // nodes are reachable from sink node and won't be removed. |
| FixupSourceAndSinkEdges(graph->get()); |
| |
| // Step 1: create a key placeholder node. |
| TF_ASSIGN_OR_RETURN( |
| Node * key_placeholder, |
| AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get())); |
| |
| // Step 2: build RecvAtHost node, and replace all _Arg nodes with it. |
| std::vector<DataType> recv_at_host_dtypes; |
| TF_ASSIGN_OR_RETURN( |
| Node * recv_at_host_node, |
| ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name, |
| &recv_at_host_dtypes, key_placeholder)); |
| |
| // Step 3: build SendFromHost node, and replace all _Retval nodes with it. |
| std::vector<DataType> send_from_host_dtypes; |
| TF_ASSIGN_OR_RETURN( |
| Node * send_from_host_node, |
| ReplaceRetNodesWithSendFromHostNode( |
| graph->get(), new_name, &send_from_host_dtypes, key_placeholder)); |
| |
| // Step 4: add XLA cluster and outside compilation attr. |
| for (Node* n : (*graph)->nodes()) { |
| if (IsKeyPlaceholderNode(*n)) { |
| continue; |
| } |
| |
| n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_); |
| n->AddAttr(outside_compilation_attr_name_, old_name); |
| } |
| |
| // Check whether we have all input shapes for XlaSendFromHost. If we do, we |
| // will set `shapes` attr for the call node; otherwise we will save the |
| // shape inference graph and set `shape_inference_graph` for the call node. |
| absl::optional<std::vector<PartialTensorShape>> shapes = |
| GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node); |
| for (Node* n : (*graph)->nodes()) { |
| n->ClearAttr(kXlaInferredShapesAttrName); |
| } |
| |
| // Step 5: add control edges for originally XLA <-> outside compilation |
| // control edges. |
| for (Node* n : (*graph)->nodes()) { |
| if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) { |
| (*graph)->AddControlEdge(n, send_from_host_node); |
| n->ClearAttr(kXlaConnectedToXlaComputationAttrName); |
| } |
| if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) { |
| (*graph)->AddControlEdge(recv_at_host_node, n); |
| n->ClearAttr(kXlaConnectedFromXlaComputationAttrName); |
| } |
| } |
| |
| // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune |
| // them if necessary. |
| // - RecvAtHost should be pruned iff it has no output data/control edges. If |
| // it has any output edge, it will be reverse reachable from sink node. We |
| // don't need to do anything special. |
| // - SendFromHost should be pruned iff it has no input data/control edges. If |
| // it has input edges other than key_placeholder, we connect it to sink |
| // node so it won't be pruned. |
| // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned. |
| // We don't need to do anything special. |
| if (send_from_host_node->in_edges().size() > 1) { |
| (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node()); |
| } |
| PruneForReverseReachability( |
| graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()}); |
| |
| // Step 7: add necessary attributes to function call node, so we can replace |
| // it with HostCompute node later. |
| AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); |
| if (shapes) { |
| NameAttrList shape_inference_graph; |
| AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); |
| AddNodeAttr("shapes", *shapes, node_def); |
| } else { |
| string shape_inference_func_name = |
| absl::StrCat("_outside_compilation_shape_inference_", new_name); |
| NameAttrList shape_inference_graph; |
| shape_inference_graph.set_name(shape_inference_func_name); |
| AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); |
| AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def); |
| } |
| AddNodeAttr("ancestors", std::vector<string>{}, node_def); |
| AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def); |
| AddNodeAttr("Toutputs", send_from_host_dtypes, node_def); |
| AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def); |
| |
| return Status::OK(); |
| } |
| |
| Status ExtractOutsideCompilationForFunction( |
| const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, const string& xla_cluster_name, |
| const NameAttrList& func_name_attrs, const string& new_func_name, |
| const string& host_graph_func_name, |
| const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr, |
| FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs, |
| bool* has_outside_compilation) { |
| // Convert the function to graph. |
| const string& func_name = func_name_attrs.name(); |
| FunctionLibraryRuntime::Handle handle; |
| TF_RETURN_IF_ERROR( |
| flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); |
| Status ret_status = Status::OK(); |
| auto cleanup_handle = gtl::MakeCleanup([&]() { |
| auto s = flr->ReleaseHandle(handle); |
| if (!s.ok()) { |
| ret_status.Update(s); |
| } |
| }); |
| const FunctionBody* fbody = flr->GetFunctionBody(handle); |
| |
| // Check if we have outside compilation nodes. |
| *has_outside_compilation = false; |
| for (Node* n : fbody->graph->nodes()) { |
| if (HasNodeAttr(n->def(), outside_compilation_attr_name)) { |
| *has_outside_compilation = true; |
| break; |
| } |
| } |
| // We cannot early return here, because we might have outside compilation in |
| // If/While function body. |
| |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile( |
| absl::StrCat("extract_outside_compilation_for_func_before_", func_name), |
| *fbody->graph, fld); |
| } |
| |
| // Find dependencies between outside compilation clusters. |
| TF_ASSIGN_OR_RETURN(auto cluster_deps, |
| OutsideCompilationClusterDependencies( |
| fbody->graph, outside_compilation_attr_name)); |
| |
| // Preprocess edges between different outside compilations. They will be |
| // restored in `ConstructHostGraph()`. |
| TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( |
| fbody->graph, outside_compilation_attr_name)); |
| |
| // Encapsulate outside_compilation cluster into function call node. |
| std::unique_ptr<Graph> graph_out; |
| auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| new_func_name); |
| TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( |
| outside_compilation_attr_name, *fbody->graph, *rewrite_fn, |
| /*reuse_existing_functions=*/true, &graph_out, fld)); |
| |
| // Replace outside_compilation function nodes with HostCompute ops. |
| std::vector<Node*> outside_compilation_nodes; |
| std::vector<string> outside_compilation_host_graphs; |
| std::vector<string> shape_inference_graphs_to_rewrite; |
| for (Node* n : graph_out->nodes()) { |
| if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) { |
| outside_compilation_nodes.push_back(n); |
| outside_compilation_host_graphs.push_back(n->name()); |
| |
| // If we could not infer shapes for XlaSendFromHost inputs statically, we |
| // will set the "shape_inference_graph" attribute. In that case, copy |
| // outside compilation subgraph as shape inference graph in `fld`. |
| auto shape_inference_graph = absl::make_unique<NameAttrList>(); |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", |
| shape_inference_graph.get())); |
| if (!shape_inference_graph->name().empty()) { |
| shape_inference_graphs->push_back(shape_inference_graph->name()); |
| shape_inference_graphs_to_rewrite.push_back( |
| shape_inference_graph->name()); |
| |
| const FunctionDef* xla_fdef = fld->Find(n->name()); |
| if (!xla_fdef) { |
| return errors::Internal("Cannot find XLA function ", n->name()); |
| } |
| auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef); |
| shape_inference_fdef->mutable_signature()->set_name( |
| shape_inference_graph->name()); |
| if (fld->Find(shape_inference_graph->name())) { |
| TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph->name(), |
| *shape_inference_fdef)); |
| } else { |
| TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef)); |
| } |
| } |
| } |
| } |
| std::map<string, Node*> host_compute_nodes; |
| for (Node* n : outside_compilation_nodes) { |
| TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); |
| auto host_compute_node_or = ReplaceOutsideCompilationCallNode( |
| graph_out.get(), n, host_compute_core, *cluster_deps); |
| TF_RETURN_IF_ERROR(host_compute_node_or.status()); |
| Node* host_compute_node = host_compute_node_or.ValueOrDie(); |
| host_compute_nodes[host_compute_node->name()] = host_compute_node; |
| } |
| // For XlaHostCompute nodes with dependencies, add control edges between them |
| // so XlaCompiler can handle them in correct order. |
| for (auto iter : host_compute_nodes) { |
| Node* host_compute_node = iter.second; |
| std::vector<string> token_input_node_names; |
| TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(), |
| kXlaTokenInputNodesAttrName, |
| &token_input_node_names)); |
| for (const string& node_name : token_input_node_names) { |
| if (node_name == kXlaTokenArgNodeName) { |
| continue; |
| } |
| |
| auto iter = host_compute_nodes.find(node_name); |
| TF_RET_CHECK(iter != host_compute_nodes.end()); |
| graph_out->AddControlEdge(iter->second, host_compute_node); |
| } |
| } |
| |
| // Handle nodes with associated functions. |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( |
| graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, |
| xla_cluster_name, host_compute_core, flr, fld, |
| &outside_compilation_host_graphs, shape_inference_graphs, |
| has_outside_compilation)); |
| |
| // Construct host graph. |
| std::unique_ptr<Graph> host_graph; |
| TF_RETURN_IF_ERROR( |
| ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, |
| outside_compilation_host_graphs, fld, &host_graph)); |
| auto host_graph_fdef = absl::make_unique<FunctionDef>(); |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name, |
| HostGraphControlRetMapping, |
| host_graph_fdef.get())); |
| if (fld->Find(host_graph_func_name)) { |
| TF_RETURN_IF_ERROR( |
| fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef)); |
| } else { |
| TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef)); |
| } |
| |
| // Shape inference graphs might contain Placeholder nodes for outside |
| // compilation to outside compilation edges. Rewrite shape inference graphs |
| // to remove such nodes. |
| for (const string& shape_inference_graph : |
| shape_inference_graphs_to_rewrite) { |
| TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph, |
| host_graph.get(), |
| /*pivot_node=*/nullptr, fld)); |
| } |
| |
| // Remove the outside compilation graphs from function library. |
| for (const string& func : outside_compilation_host_graphs) { |
| TF_RETURN_IF_ERROR(fld->RemoveFunction(func)); |
| } |
| |
| // Replace original function. |
| auto updated_fdef = absl::make_unique<FunctionDef>(); |
| TF_RETURN_IF_ERROR( |
| GraphToFunctionDef(*graph_out, new_func_name, updated_fdef.get())); |
| const FunctionDef* original_fdef = fld->Find(func_name); |
| if (original_fdef) { |
| for (const auto& attr : original_fdef->attr()) { |
| (*updated_fdef->mutable_attr())[attr.first] = attr.second; |
| } |
| } |
| if (fld->Find(new_func_name)) { |
| TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef)); |
| } else { |
| TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef)); |
| } |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile( |
| absl::StrCat("extract_outside_compilation_for_func_after_", func_name), |
| *graph_out, fld); |
| } |
| |
| return ret_status; |
| } |
| |
| Status ExtractOutsideCompilation( |
| const string& xla_cluster_attr_name, |
| const string& outside_compilation_attr_name, |
| const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g, |
| FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, |
| bool* modified) { |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile("extract_outside_compilation_before", *g, fld); |
| } |
| |
| *modified = false; |
| auto node_name_index = g->BuildNodeNameIndex(); |
| for (auto& iter : clusters) { |
| string xla_cluster_name = iter.first; |
| Node* n = iter.second.node; |
| auto const& func_name_attrs = iter.second.func_name_attrs; |
| auto const& host_compute_core = iter.second.host_compute_core; |
| |
| std::vector<string> shape_inference_graphs; |
| bool has_outside_compilation; |
| string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); |
| TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( |
| xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, |
| func_name_attrs, func_name_attrs.name(), host_graph_func_name, |
| host_compute_core, flr, fld, &shape_inference_graphs, |
| &has_outside_compilation)); |
| *modified |= has_outside_compilation; |
| |
| string pivot_name = absl::StrCat(xla_cluster_name, "/pivot"); |
| Node* pivot_node = node_name_index[pivot_name]; |
| TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph( |
| g, fld, host_graph_func_name, n, pivot_node)); |
| |
| TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); |
| |
| for (auto shape_inference_graph_name : shape_inference_graphs) { |
| TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name, |
| g, pivot_node, fld)); |
| } |
| } |
| |
| if (VLOG_IS_ON(4)) { |
| DumpGraphToFile("extract_outside_compilation_after", *g, fld); |
| } |
| return Status::OK(); |
| } |
| |
| } // namespace tensorflow |