Only perform outside compilation related logic if there are outside compilation nodes.
PiperOrigin-RevId: 273332365
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index b35e08fb1..7aa7f17 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -1010,8 +1010,10 @@
protobuf::Map<string, AttrValue> attrs;
attrs["_device_ordinal"] = device_ordinal_attr;
std::unique_ptr<FunctionBody> host_fbody;
- TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(host_func), AttrSlice(&attrs), fld, &host_fbody));
+ const FunctionDef* host_fdef = fld->Find(host_func);
+ TF_RET_CHECK(host_fdef);
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
+ fld, &host_fbody));
// We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
// reachable from sink node so all nodes will be copied.
@@ -1121,7 +1123,9 @@
protobuf::Map<string, AttrValue> attrs;
attrs["_device_ordinal"] = device_ordinal_attr;
std::unique_ptr<FunctionBody> fbody;
- TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(host_graph_func_name),
+ const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
+ TF_RET_CHECK(host_graph_func);
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
AttrSlice(&attrs), fld, &fbody));
Graph* host_graph = fbody->graph;
@@ -1197,8 +1201,11 @@
protobuf::Map<string, AttrValue> attrs;
attrs["_device_ordinal"] = device_ordinal_attr;
std::unique_ptr<FunctionBody> fbody;
- TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, &fbody));
+ const FunctionDef* shape_inference_graph =
+ fld->Find(shape_inference_graph_name);
+ TF_RET_CHECK(shape_inference_graph);
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
+ AttrSlice(&attrs), fld, &fbody));
Graph* g = fbody->graph;
// Find SendFromHost node.
@@ -1336,8 +1343,9 @@
protobuf::Map<string, AttrValue> attrs;
attrs["_device_ordinal"] = device_ordinal_attr;
std::unique_ptr<FunctionBody> fbody;
- TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(func_name),
- AttrSlice(&attrs), fld, &fbody));
+ const FunctionDef* func = fld->Find(func_name);
+ TF_RETURN_IF_ERROR(
+ FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
Graph* g = fbody->graph;
// Find or create the key placeholder node.
@@ -1460,9 +1468,10 @@
const string& while_node_name, const string& host_transfer_key) {
// Instantiate the loop cond function.
std::unique_ptr<FunctionBody> fbody;
- TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()),
- AttrSlice(&loop_cond_func.attr()),
- fld, &fbody));
+ const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name());
+ TF_RET_CHECK(loop_cond_fdef);
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+ *loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld, &fbody));
Graph* g = fbody->graph;
// Find the _Retval node and the loop cond node.
@@ -1527,8 +1536,9 @@
protobuf::Map<string, AttrValue> attrs;
attrs["_device_ordinal"] = device_ordinal_temp_value;
std::unique_ptr<FunctionBody> cond_fbody;
- TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, &cond_fbody));
+ const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
+ fld, &cond_fbody));
Graph* cond_graph = cond_fbody->graph;
Node* key_arg = nullptr;
for (Node* n : cond_graph->nodes()) {
@@ -1602,8 +1612,10 @@
protobuf::Map<string, AttrValue> attrs;
attrs["_device_ordinal"] = device_ordinal_temp_value;
std::unique_ptr<FunctionBody> body_fbody;
- TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, &body_fbody));
+ const FunctionDef* body_host_func = fld->Find(body_host_func_name);
+ TF_RET_CHECK(body_host_func);
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
+ fld, &body_fbody));
Graph* body_graph = body_fbody->graph;
Node* key_arg = nullptr;
for (Node* n : body_graph->nodes()) {
@@ -1880,12 +1892,16 @@
*has_outside_compilation = true;
// Change If node to call the new functions.
- then_branch.set_name(then_branch_xla_func_name);
- n->ClearAttr("then_branch");
- n->AddAttr("then_branch", then_branch);
- else_branch.set_name(else_branch_xla_func_name);
- n->ClearAttr("else_branch");
- n->AddAttr("else_branch", else_branch);
+ if (then_branch_has_outside_compilation) {
+ then_branch.set_name(then_branch_xla_func_name);
+ n->ClearAttr("then_branch");
+ n->AddAttr("then_branch", then_branch);
+ }
+ if (else_branch_has_outside_compilation) {
+ else_branch.set_name(else_branch_xla_func_name);
+ n->ClearAttr("else_branch");
+ n->AddAttr("else_branch", else_branch);
+ }
string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
@@ -1905,6 +1921,43 @@
g->AddControlEdge(send_pred_node, n);
// Build host side graph for the "If" node.
+ // If then/else branch does not have outside compilation, we won't build host
+ // graph for the branch. But here we need a host graph for both branches, so
+ // we need to create a no-op host graph.
+ if (!then_branch_has_outside_compilation) {
+ std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
+ std::vector<string> then_branch_host_graphs;
+ TF_RETURN_IF_ERROR(ConstructHostGraph(
+ xla_cluster_name, outside_compilation_attr_name,
+ then_branch_host_graphs, fld, &then_branch_host_graph));
+ FunctionDef then_branch_host_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
+ then_branch_host_func_name,
+ &then_branch_host_fdef));
+ if (fld->Find(then_branch_host_func_name)) {
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
+ then_branch_host_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
+ }
+ }
+ if (!else_branch_has_outside_compilation) {
+ std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
+ std::vector<string> else_branch_host_graphs;
+ TF_RETURN_IF_ERROR(ConstructHostGraph(
+ xla_cluster_name, outside_compilation_attr_name,
+ else_branch_host_graphs, fld, &else_branch_host_graph));
+ FunctionDef else_branch_host_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
+ else_branch_host_func_name,
+ &else_branch_host_fdef));
+ if (fld->Find(else_branch_host_func_name)) {
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
+ else_branch_host_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
+ }
+ }
string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
@@ -1952,12 +2005,16 @@
*has_outside_compilation = true;
// Change While node to call the new functions.
- cond.set_name(cond_xla_func_name);
- n->ClearAttr("cond");
- n->AddAttr("cond", cond);
- body.set_name(body_xla_func_name);
- n->ClearAttr("body");
- n->AddAttr("body", body);
+ if (cond_has_outside_compilation) {
+ cond.set_name(cond_xla_func_name);
+ n->ClearAttr("cond");
+ n->AddAttr("cond", cond);
+ }
+ if (body_has_outside_compilation) {
+ body.set_name(body_xla_func_name);
+ n->ClearAttr("body");
+ n->AddAttr("body", body);
+ }
string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
@@ -1969,6 +2026,38 @@
std::vector<string>{kXlaTokenArgNodeName});
// Build host side graph for the "While" node.
+ if (!cond_has_outside_compilation) {
+ std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
+ std::vector<string> host_graphs;
+ TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
+ outside_compilation_attr_name,
+ host_graphs, fld, &cond_host_graph));
+ FunctionDef cond_host_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
+ &cond_host_fdef));
+ if (fld->Find(cond_host_func_name)) {
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
+ }
+ }
+ if (!body_has_outside_compilation) {
+ std::unique_ptr<Graph> body_host_graph(new Graph(fld));
+ std::vector<string> host_graphs;
+ TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
+ outside_compilation_attr_name,
+ host_graphs, fld, &body_host_graph));
+ FunctionDef body_host_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
+ &body_host_fdef));
+ if (fld->Find(body_host_func_name)) {
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(body_host_func_name, body_host_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
+ }
+ }
string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
@@ -2166,146 +2255,151 @@
*fbody->graph, fld);
}
- // Find dependencies between outside compilation clusters.
- TF_ASSIGN_OR_RETURN(auto cluster_deps,
- OutsideCompilationClusterDependencies(
- fbody->graph, outside_compilation_attr_name));
-
- // Preprocess edges between different outside compilations. They will be
- // restored in `ConstructHostGraph()`.
- TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
- fbody->graph, outside_compilation_attr_name));
-
- // Encapsulate outside_compilation cluster into function call node.
std::unique_ptr<Graph> graph_out;
- auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
- xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
- new_func_name);
- TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
- outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
- /*reuse_existing_functions=*/true, &graph_out, fld));
-
- // Replace outside_compilation function nodes with HostCompute ops.
- std::vector<Node*> outside_compilation_nodes;
std::vector<string> outside_compilation_host_graphs;
std::vector<string> shape_inference_graphs_to_rewrite;
- for (Node* n : graph_out->nodes()) {
- if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
- outside_compilation_nodes.push_back(n);
- outside_compilation_host_graphs.push_back(n->name());
+ if (*has_outside_compilation) {
+ // Find dependencies between outside compilation clusters.
+ TF_ASSIGN_OR_RETURN(auto cluster_deps,
+ OutsideCompilationClusterDependencies(
+ fbody->graph, outside_compilation_attr_name));
- // If we could not infer shapes for XlaSendFromHost inputs statically, we
- // will set the "shape_inference_graph" attribute. In that case, copy
- // outside compilation subgraph as shape inference graph in `fld`.
- auto shape_inference_graph = absl::make_unique<NameAttrList>();
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
- shape_inference_graph.get()));
- if (!shape_inference_graph->name().empty()) {
- shape_inference_graphs->push_back(shape_inference_graph->name());
- shape_inference_graphs_to_rewrite.push_back(
- shape_inference_graph->name());
+ // Preprocess edges between different outside compilations. They will be
+ // restored in `ConstructHostGraph()`.
+ TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
+ fbody->graph, outside_compilation_attr_name));
- const FunctionDef* xla_fdef = fld->Find(n->name());
- if (!xla_fdef) {
- return errors::Internal("Cannot find XLA function ", n->name());
- }
- auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
- shape_inference_fdef->mutable_signature()->set_name(
- shape_inference_graph->name());
- if (fld->Find(shape_inference_graph->name())) {
- TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph->name(),
- *shape_inference_fdef));
- } else {
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
+ // Encapsulate outside_compilation cluster into function call node.
+ auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
+ xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+ new_func_name);
+ TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
+ outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
+ /*reuse_existing_functions=*/true, &graph_out, fld));
+
+ // Replace outside_compilation function nodes with HostCompute ops.
+ std::vector<Node*> outside_compilation_nodes;
+ for (Node* n : graph_out->nodes()) {
+ if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
+ outside_compilation_nodes.push_back(n);
+ outside_compilation_host_graphs.push_back(n->name());
+
+ // If we could not infer shapes for XlaSendFromHost inputs statically,
+ // we will set the "shape_inference_graph" attribute. In that case, copy
+ // outside compilation subgraph as shape inference graph in `fld`.
+ auto shape_inference_graph = absl::make_unique<NameAttrList>();
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
+ shape_inference_graph.get()));
+ if (!shape_inference_graph->name().empty()) {
+ shape_inference_graphs->push_back(shape_inference_graph->name());
+ shape_inference_graphs_to_rewrite.push_back(
+ shape_inference_graph->name());
+
+ const FunctionDef* xla_fdef = fld->Find(n->name());
+ if (!xla_fdef) {
+ return errors::Internal("Cannot find XLA function ", n->name());
+ }
+ auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
+ shape_inference_fdef->mutable_signature()->set_name(
+ shape_inference_graph->name());
+ if (fld->Find(shape_inference_graph->name())) {
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(
+ shape_inference_graph->name(), *shape_inference_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
+ }
}
}
}
- }
- std::map<string, Node*> host_compute_nodes;
- for (Node* n : outside_compilation_nodes) {
- TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
- auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
- graph_out.get(), n, host_compute_core, *cluster_deps);
- TF_RETURN_IF_ERROR(host_compute_node_or.status());
- Node* host_compute_node = host_compute_node_or.ValueOrDie();
- host_compute_nodes[host_compute_node->name()] = host_compute_node;
- }
- // For XlaHostCompute nodes with dependencies, add control edges between them
- // so XlaCompiler can handle them in correct order.
- for (auto iter : host_compute_nodes) {
- Node* host_compute_node = iter.second;
- std::vector<string> token_input_node_names;
- TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
- kXlaTokenInputNodesAttrName,
- &token_input_node_names));
- for (const string& node_name : token_input_node_names) {
- if (node_name == kXlaTokenArgNodeName) {
- continue;
- }
+ std::map<string, Node*> host_compute_nodes;
+ for (Node* n : outside_compilation_nodes) {
+ TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
+ auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
+ graph_out.get(), n, host_compute_core, *cluster_deps);
+ TF_RETURN_IF_ERROR(host_compute_node_or.status());
+ Node* host_compute_node = host_compute_node_or.ValueOrDie();
+ host_compute_nodes[host_compute_node->name()] = host_compute_node;
+ }
+ // For XlaHostCompute nodes with dependencies, add control edges between
+ // them so XlaCompiler can handle them in correct order.
+ for (auto iter : host_compute_nodes) {
+ Node* host_compute_node = iter.second;
+ std::vector<string> token_input_node_names;
+ TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
+ kXlaTokenInputNodesAttrName,
+ &token_input_node_names));
+ for (const string& node_name : token_input_node_names) {
+ if (node_name == kXlaTokenArgNodeName) {
+ continue;
+ }
- auto iter = host_compute_nodes.find(node_name);
- TF_RET_CHECK(iter != host_compute_nodes.end());
- graph_out->AddControlEdge(iter->second, host_compute_node);
+ auto iter = host_compute_nodes.find(node_name);
+ TF_RET_CHECK(iter != host_compute_nodes.end());
+ graph_out->AddControlEdge(iter->second, host_compute_node);
+ }
}
}
// Handle nodes with associated functions.
+ Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
- graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name,
- xla_cluster_name, host_compute_core, flr, fld,
- &outside_compilation_host_graphs, shape_inference_graphs,
- has_outside_compilation));
+ g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+ host_compute_core, flr, fld, &outside_compilation_host_graphs,
+ shape_inference_graphs, has_outside_compilation));
- // Construct host graph.
- std::unique_ptr<Graph> host_graph;
- TF_RETURN_IF_ERROR(
- ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
- outside_compilation_host_graphs, fld, &host_graph));
- auto host_graph_fdef = absl::make_unique<FunctionDef>();
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
- HostGraphControlRetMapping,
- host_graph_fdef.get()));
- if (fld->Find(host_graph_func_name)) {
+ if (*has_outside_compilation) {
+ // Construct host graph.
+ std::unique_ptr<Graph> host_graph;
TF_RETURN_IF_ERROR(
- fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
- } else {
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
- }
-
- // Shape inference graphs might contain Placeholder nodes for outside
- // compilation to outside compilation edges. Rewrite shape inference graphs
- // to remove such nodes.
- for (const string& shape_inference_graph :
- shape_inference_graphs_to_rewrite) {
- TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph,
- host_graph.get(),
- /*pivot_node=*/nullptr, fld));
- }
-
- // Remove the outside compilation graphs from function library.
- for (const string& func : outside_compilation_host_graphs) {
- TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
- }
-
- // Replace original function.
- auto updated_fdef = absl::make_unique<FunctionDef>();
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*graph_out, new_func_name, updated_fdef.get()));
- const FunctionDef* original_fdef = fld->Find(func_name);
- if (original_fdef) {
- for (const auto& attr : original_fdef->attr()) {
- (*updated_fdef->mutable_attr())[attr.first] = attr.second;
+ ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
+ outside_compilation_host_graphs, fld, &host_graph));
+ auto host_graph_fdef = absl::make_unique<FunctionDef>();
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
+ HostGraphControlRetMapping,
+ host_graph_fdef.get()));
+ if (fld->Find(host_graph_func_name)) {
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
}
- }
- if (fld->Find(new_func_name)) {
- TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
- } else {
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
- }
- if (VLOG_IS_ON(4)) {
- DumpGraphToFile(
- absl::StrCat("extract_outside_compilation_for_func_after_", func_name),
- *graph_out, fld);
+
+ // Shape inference graphs might contain Placeholder nodes for outside
+ // compilation to outside compilation edges. Rewrite shape inference graphs
+ // to remove such nodes.
+ for (const string& shape_inference_graph :
+ shape_inference_graphs_to_rewrite) {
+ TF_RETURN_IF_ERROR(
+ RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
+ /*pivot_node=*/nullptr, fld));
+ }
+
+ // Remove the outside compilation graphs from function library.
+ for (const string& func : outside_compilation_host_graphs) {
+ TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
+ }
+
+ // Replace original function.
+ auto updated_fdef = absl::make_unique<FunctionDef>();
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
+ const FunctionDef* original_fdef = fld->Find(func_name);
+ if (original_fdef) {
+ for (const auto& attr : original_fdef->attr()) {
+ (*updated_fdef->mutable_attr())[attr.first] = attr.second;
+ }
+ }
+ if (fld->Find(new_func_name)) {
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
+ }
+ if (VLOG_IS_ON(4)) {
+ DumpGraphToFile(
+ absl::StrCat("extract_outside_compilation_for_func_after_",
+ func_name),
+ *g, fld);
+ }
}
return ret_status;
@@ -2339,16 +2433,18 @@
&has_outside_compilation));
*modified |= has_outside_compilation;
- string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
- Node* pivot_node = node_name_index[pivot_name];
- TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
- g, fld, host_graph_func_name, n, pivot_node));
+ if (has_outside_compilation) {
+ string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
+ Node* pivot_node = node_name_index[pivot_name];
+ TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
+ g, fld, host_graph_func_name, n, pivot_node));
- TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
+ TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
- for (auto shape_inference_graph_name : shape_inference_graphs) {
- TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name,
- g, pivot_node, fld));
+ for (auto shape_inference_graph_name : shape_inference_graphs) {
+ TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
+ shape_inference_graph_name, g, pivot_node, fld));
+ }
}
}
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
index 049ee82..907fe20 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -418,16 +418,8 @@
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
- // Check host graph is empty.
- std::unique_ptr<FunctionBody> host_fbody;
- AttrValue device_ordinal_temp_value;
- device_ordinal_temp_value.set_i(0);
- protobuf::Map<string, AttrValue> host_func_attrs;
- host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
- TF_CHECK_OK(FunctionDefToBodyHelper(
- *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody));
- Graph *host_graph = host_fbody->graph;
- EXPECT_EQ(host_graph->num_nodes(), 2);
+ // Check host graph is not created.
+ EXPECT_EQ(fld.Find("host_graph"), nullptr);
}
REGISTER_OP("XlaSendToHost")