Only perform outside compilation related logic if there are outside compilation nodes.

PiperOrigin-RevId: 273332365
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index b35e08fb1..7aa7f17 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -1010,8 +1010,10 @@
     protobuf::Map<string, AttrValue> attrs;
     attrs["_device_ordinal"] = device_ordinal_attr;
     std::unique_ptr<FunctionBody> host_fbody;
-    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
-        *fld->Find(host_func), AttrSlice(&attrs), fld, &host_fbody));
+    const FunctionDef* host_fdef = fld->Find(host_func);
+    TF_RET_CHECK(host_fdef);
+    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
+                                               fld, &host_fbody));
 
     // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
     // reachable from sink node so all nodes will be copied.
@@ -1121,7 +1123,9 @@
   protobuf::Map<string, AttrValue> attrs;
   attrs["_device_ordinal"] = device_ordinal_attr;
   std::unique_ptr<FunctionBody> fbody;
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(host_graph_func_name),
+  const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
+  TF_RET_CHECK(host_graph_func);
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
                                              AttrSlice(&attrs), fld, &fbody));
   Graph* host_graph = fbody->graph;
 
@@ -1197,8 +1201,11 @@
   protobuf::Map<string, AttrValue> attrs;
   attrs["_device_ordinal"] = device_ordinal_attr;
   std::unique_ptr<FunctionBody> fbody;
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
-      *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, &fbody));
+  const FunctionDef* shape_inference_graph =
+      fld->Find(shape_inference_graph_name);
+  TF_RET_CHECK(shape_inference_graph);
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
+                                             AttrSlice(&attrs), fld, &fbody));
   Graph* g = fbody->graph;
 
   // Find SendFromHost node.
@@ -1336,8 +1343,9 @@
   protobuf::Map<string, AttrValue> attrs;
   attrs["_device_ordinal"] = device_ordinal_attr;
   std::unique_ptr<FunctionBody> fbody;
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(func_name),
-                                             AttrSlice(&attrs), fld, &fbody));
+  const FunctionDef* func = fld->Find(func_name);
+  TF_RETURN_IF_ERROR(
+      FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
   Graph* g = fbody->graph;
 
   // Find or create the key placeholder node.
@@ -1460,9 +1468,10 @@
     const string& while_node_name, const string& host_transfer_key) {
   // Instantiate the loop cond function.
   std::unique_ptr<FunctionBody> fbody;
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()),
-                                             AttrSlice(&loop_cond_func.attr()),
-                                             fld, &fbody));
+  const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name());
+  TF_RET_CHECK(loop_cond_fdef);
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld, &fbody));
   Graph* g = fbody->graph;
 
   // Find the _Retval node and the loop cond node.
@@ -1527,8 +1536,9 @@
   protobuf::Map<string, AttrValue> attrs;
   attrs["_device_ordinal"] = device_ordinal_temp_value;
   std::unique_ptr<FunctionBody> cond_fbody;
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
-      *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, &cond_fbody));
+  const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
+                                             fld, &cond_fbody));
   Graph* cond_graph = cond_fbody->graph;
   Node* key_arg = nullptr;
   for (Node* n : cond_graph->nodes()) {
@@ -1602,8 +1612,10 @@
   protobuf::Map<string, AttrValue> attrs;
   attrs["_device_ordinal"] = device_ordinal_temp_value;
   std::unique_ptr<FunctionBody> body_fbody;
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
-      *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, &body_fbody));
+  const FunctionDef* body_host_func = fld->Find(body_host_func_name);
+  TF_RET_CHECK(body_host_func);
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
+                                             fld, &body_fbody));
   Graph* body_graph = body_fbody->graph;
   Node* key_arg = nullptr;
   for (Node* n : body_graph->nodes()) {
@@ -1880,12 +1892,16 @@
   *has_outside_compilation = true;
 
   // Change If node to call the new functions.
-  then_branch.set_name(then_branch_xla_func_name);
-  n->ClearAttr("then_branch");
-  n->AddAttr("then_branch", then_branch);
-  else_branch.set_name(else_branch_xla_func_name);
-  n->ClearAttr("else_branch");
-  n->AddAttr("else_branch", else_branch);
+  if (then_branch_has_outside_compilation) {
+    then_branch.set_name(then_branch_xla_func_name);
+    n->ClearAttr("then_branch");
+    n->AddAttr("then_branch", then_branch);
+  }
+  if (else_branch_has_outside_compilation) {
+    else_branch.set_name(else_branch_xla_func_name);
+    n->ClearAttr("else_branch");
+    n->AddAttr("else_branch", else_branch);
+  }
 
   string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
 
@@ -1905,6 +1921,43 @@
   g->AddControlEdge(send_pred_node, n);
 
   // Build host side graph for the "If" node.
+  // If then/else branch does not have outside compilation, we won't build host
+  // graph for the branch. But here we need a host graph for both branches, so
+  // we need to create a no-op host graph.
+  if (!then_branch_has_outside_compilation) {
+    std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
+    std::vector<string> then_branch_host_graphs;
+    TF_RETURN_IF_ERROR(ConstructHostGraph(
+        xla_cluster_name, outside_compilation_attr_name,
+        then_branch_host_graphs, fld, &then_branch_host_graph));
+    FunctionDef then_branch_host_fdef;
+    TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
+                                          then_branch_host_func_name,
+                                          &then_branch_host_fdef));
+    if (fld->Find(then_branch_host_func_name)) {
+      TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
+                                              then_branch_host_fdef));
+    } else {
+      TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
+    }
+  }
+  if (!else_branch_has_outside_compilation) {
+    std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
+    std::vector<string> else_branch_host_graphs;
+    TF_RETURN_IF_ERROR(ConstructHostGraph(
+        xla_cluster_name, outside_compilation_attr_name,
+        else_branch_host_graphs, fld, &else_branch_host_graph));
+    FunctionDef else_branch_host_fdef;
+    TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
+                                          else_branch_host_func_name,
+                                          &else_branch_host_fdef));
+    if (fld->Find(else_branch_host_func_name)) {
+      TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
+                                              else_branch_host_fdef));
+    } else {
+      TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
+    }
+  }
   string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
   TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
@@ -1952,12 +2005,16 @@
   *has_outside_compilation = true;
 
   // Change While node to call the new functions.
-  cond.set_name(cond_xla_func_name);
-  n->ClearAttr("cond");
-  n->AddAttr("cond", cond);
-  body.set_name(body_xla_func_name);
-  n->ClearAttr("body");
-  n->AddAttr("body", body);
+  if (cond_has_outside_compilation) {
+    cond.set_name(cond_xla_func_name);
+    n->ClearAttr("cond");
+    n->AddAttr("cond", cond);
+  }
+  if (body_has_outside_compilation) {
+    body.set_name(body_xla_func_name);
+    n->ClearAttr("body");
+    n->AddAttr("body", body);
+  }
 
   string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
 
@@ -1969,6 +2026,38 @@
              std::vector<string>{kXlaTokenArgNodeName});
 
   // Build host side graph for the "While" node.
+  if (!cond_has_outside_compilation) {
+    std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
+    std::vector<string> host_graphs;
+    TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
+                                          outside_compilation_attr_name,
+                                          host_graphs, fld, &cond_host_graph));
+    FunctionDef cond_host_fdef;
+    TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
+                                          &cond_host_fdef));
+    if (fld->Find(cond_host_func_name)) {
+      TF_RETURN_IF_ERROR(
+          fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
+    } else {
+      TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
+    }
+  }
+  if (!body_has_outside_compilation) {
+    std::unique_ptr<Graph> body_host_graph(new Graph(fld));
+    std::vector<string> host_graphs;
+    TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
+                                          outside_compilation_attr_name,
+                                          host_graphs, fld, &body_host_graph));
+    FunctionDef body_host_fdef;
+    TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
+                                          &body_host_fdef));
+    if (fld->Find(body_host_func_name)) {
+      TF_RETURN_IF_ERROR(
+          fld->ReplaceFunction(body_host_func_name, body_host_fdef));
+    } else {
+      TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
+    }
+  }
   string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
   TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
@@ -2166,146 +2255,151 @@
         *fbody->graph, fld);
   }
 
-  // Find dependencies between outside compilation clusters.
-  TF_ASSIGN_OR_RETURN(auto cluster_deps,
-                      OutsideCompilationClusterDependencies(
-                          fbody->graph, outside_compilation_attr_name));
-
-  // Preprocess edges between different outside compilations. They will be
-  // restored in `ConstructHostGraph()`.
-  TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
-      fbody->graph, outside_compilation_attr_name));
-
-  // Encapsulate outside_compilation cluster into function call node.
   std::unique_ptr<Graph> graph_out;
-  auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
-      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-      new_func_name);
-  TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
-      outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
-      /*reuse_existing_functions=*/true, &graph_out, fld));
-
-  // Replace outside_compilation function nodes with HostCompute ops.
-  std::vector<Node*> outside_compilation_nodes;
   std::vector<string> outside_compilation_host_graphs;
   std::vector<string> shape_inference_graphs_to_rewrite;
-  for (Node* n : graph_out->nodes()) {
-    if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
-      outside_compilation_nodes.push_back(n);
-      outside_compilation_host_graphs.push_back(n->name());
+  if (*has_outside_compilation) {
+    // Find dependencies between outside compilation clusters.
+    TF_ASSIGN_OR_RETURN(auto cluster_deps,
+                        OutsideCompilationClusterDependencies(
+                            fbody->graph, outside_compilation_attr_name));
 
-      // If we could not infer shapes for XlaSendFromHost inputs statically, we
-      // will set the "shape_inference_graph" attribute. In that case, copy
-      // outside compilation subgraph as shape inference graph in `fld`.
-      auto shape_inference_graph = absl::make_unique<NameAttrList>();
-      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
-                                     shape_inference_graph.get()));
-      if (!shape_inference_graph->name().empty()) {
-        shape_inference_graphs->push_back(shape_inference_graph->name());
-        shape_inference_graphs_to_rewrite.push_back(
-            shape_inference_graph->name());
+    // Preprocess edges between different outside compilations. They will be
+    // restored in `ConstructHostGraph()`.
+    TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
+        fbody->graph, outside_compilation_attr_name));
 
-        const FunctionDef* xla_fdef = fld->Find(n->name());
-        if (!xla_fdef) {
-          return errors::Internal("Cannot find XLA function ", n->name());
-        }
-        auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
-        shape_inference_fdef->mutable_signature()->set_name(
-            shape_inference_graph->name());
-        if (fld->Find(shape_inference_graph->name())) {
-          TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph->name(),
-                                                  *shape_inference_fdef));
-        } else {
-          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
+    // Encapsulate outside_compilation cluster into function call node.
+    auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        new_func_name);
+    TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
+        outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
+        /*reuse_existing_functions=*/true, &graph_out, fld));
+
+    // Replace outside_compilation function nodes with HostCompute ops.
+    std::vector<Node*> outside_compilation_nodes;
+    for (Node* n : graph_out->nodes()) {
+      if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
+        outside_compilation_nodes.push_back(n);
+        outside_compilation_host_graphs.push_back(n->name());
+
+        // If we could not infer shapes for XlaSendFromHost inputs statically,
+        // we will set the "shape_inference_graph" attribute. In that case, copy
+        // outside compilation subgraph as shape inference graph in `fld`.
+        auto shape_inference_graph = absl::make_unique<NameAttrList>();
+        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
+                                       shape_inference_graph.get()));
+        if (!shape_inference_graph->name().empty()) {
+          shape_inference_graphs->push_back(shape_inference_graph->name());
+          shape_inference_graphs_to_rewrite.push_back(
+              shape_inference_graph->name());
+
+          const FunctionDef* xla_fdef = fld->Find(n->name());
+          if (!xla_fdef) {
+            return errors::Internal("Cannot find XLA function ", n->name());
+          }
+          auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
+          shape_inference_fdef->mutable_signature()->set_name(
+              shape_inference_graph->name());
+          if (fld->Find(shape_inference_graph->name())) {
+            TF_RETURN_IF_ERROR(fld->ReplaceFunction(
+                shape_inference_graph->name(), *shape_inference_fdef));
+          } else {
+            TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
+          }
         }
       }
     }
-  }
-  std::map<string, Node*> host_compute_nodes;
-  for (Node* n : outside_compilation_nodes) {
-    TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
-    auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
-        graph_out.get(), n, host_compute_core, *cluster_deps);
-    TF_RETURN_IF_ERROR(host_compute_node_or.status());
-    Node* host_compute_node = host_compute_node_or.ValueOrDie();
-    host_compute_nodes[host_compute_node->name()] = host_compute_node;
-  }
-  // For XlaHostCompute nodes with dependencies, add control edges between them
-  // so XlaCompiler can handle them in correct order.
-  for (auto iter : host_compute_nodes) {
-    Node* host_compute_node = iter.second;
-    std::vector<string> token_input_node_names;
-    TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
-                                   kXlaTokenInputNodesAttrName,
-                                   &token_input_node_names));
-    for (const string& node_name : token_input_node_names) {
-      if (node_name == kXlaTokenArgNodeName) {
-        continue;
-      }
+    std::map<string, Node*> host_compute_nodes;
+    for (Node* n : outside_compilation_nodes) {
+      TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
+      auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
+          graph_out.get(), n, host_compute_core, *cluster_deps);
+      TF_RETURN_IF_ERROR(host_compute_node_or.status());
+      Node* host_compute_node = host_compute_node_or.ValueOrDie();
+      host_compute_nodes[host_compute_node->name()] = host_compute_node;
+    }
+    // For XlaHostCompute nodes with dependencies, add control edges between
+    // them so XlaCompiler can handle them in correct order.
+    for (auto iter : host_compute_nodes) {
+      Node* host_compute_node = iter.second;
+      std::vector<string> token_input_node_names;
+      TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
+                                     kXlaTokenInputNodesAttrName,
+                                     &token_input_node_names));
+      for (const string& node_name : token_input_node_names) {
+        if (node_name == kXlaTokenArgNodeName) {
+          continue;
+        }
 
-      auto iter = host_compute_nodes.find(node_name);
-      TF_RET_CHECK(iter != host_compute_nodes.end());
-      graph_out->AddControlEdge(iter->second, host_compute_node);
+        auto iter = host_compute_nodes.find(node_name);
+        TF_RET_CHECK(iter != host_compute_nodes.end());
+        graph_out->AddControlEdge(iter->second, host_compute_node);
+      }
     }
   }
 
   // Handle nodes with associated functions.
+  Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
-      graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name,
-      xla_cluster_name, host_compute_core, flr, fld,
-      &outside_compilation_host_graphs, shape_inference_graphs,
-      has_outside_compilation));
+      g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      host_compute_core, flr, fld, &outside_compilation_host_graphs,
+      shape_inference_graphs, has_outside_compilation));
 
-  // Construct host graph.
-  std::unique_ptr<Graph> host_graph;
-  TF_RETURN_IF_ERROR(
-      ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
-                         outside_compilation_host_graphs, fld, &host_graph));
-  auto host_graph_fdef = absl::make_unique<FunctionDef>();
-  TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
-                                        HostGraphControlRetMapping,
-                                        host_graph_fdef.get()));
-  if (fld->Find(host_graph_func_name)) {
+  if (*has_outside_compilation) {
+    // Construct host graph.
+    std::unique_ptr<Graph> host_graph;
     TF_RETURN_IF_ERROR(
-        fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
-  } else {
-    TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
-  }
-
-  // Shape inference graphs might contain Placeholder nodes for outside
-  // compilation to outside compilation edges. Rewrite shape inference graphs
-  // to remove such nodes.
-  for (const string& shape_inference_graph :
-       shape_inference_graphs_to_rewrite) {
-    TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph,
-                                                  host_graph.get(),
-                                                  /*pivot_node=*/nullptr, fld));
-  }
-
-  // Remove the outside compilation graphs from function library.
-  for (const string& func : outside_compilation_host_graphs) {
-    TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
-  }
-
-  // Replace original function.
-  auto updated_fdef = absl::make_unique<FunctionDef>();
-  TF_RETURN_IF_ERROR(
-      GraphToFunctionDef(*graph_out, new_func_name, updated_fdef.get()));
-  const FunctionDef* original_fdef = fld->Find(func_name);
-  if (original_fdef) {
-    for (const auto& attr : original_fdef->attr()) {
-      (*updated_fdef->mutable_attr())[attr.first] = attr.second;
+        ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
+                           outside_compilation_host_graphs, fld, &host_graph));
+    auto host_graph_fdef = absl::make_unique<FunctionDef>();
+    TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
+                                          HostGraphControlRetMapping,
+                                          host_graph_fdef.get()));
+    if (fld->Find(host_graph_func_name)) {
+      TF_RETURN_IF_ERROR(
+          fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
+    } else {
+      TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
     }
-  }
-  if (fld->Find(new_func_name)) {
-    TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
-  } else {
-    TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
-  }
-  if (VLOG_IS_ON(4)) {
-    DumpGraphToFile(
-        absl::StrCat("extract_outside_compilation_for_func_after_", func_name),
-        *graph_out, fld);
+
+    // Shape inference graphs might contain Placeholder nodes for outside
+    // compilation to outside compilation edges. Rewrite shape inference graphs
+    // to remove such nodes.
+    for (const string& shape_inference_graph :
+         shape_inference_graphs_to_rewrite) {
+      TF_RETURN_IF_ERROR(
+          RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
+                                     /*pivot_node=*/nullptr, fld));
+    }
+
+    // Remove the outside compilation graphs from function library.
+    for (const string& func : outside_compilation_host_graphs) {
+      TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
+    }
+
+    // Replace original function.
+    auto updated_fdef = absl::make_unique<FunctionDef>();
+    TF_RETURN_IF_ERROR(
+        GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
+    const FunctionDef* original_fdef = fld->Find(func_name);
+    if (original_fdef) {
+      for (const auto& attr : original_fdef->attr()) {
+        (*updated_fdef->mutable_attr())[attr.first] = attr.second;
+      }
+    }
+    if (fld->Find(new_func_name)) {
+      TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
+    } else {
+      TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
+    }
+    if (VLOG_IS_ON(4)) {
+      DumpGraphToFile(
+          absl::StrCat("extract_outside_compilation_for_func_after_",
+                       func_name),
+          *g, fld);
+    }
   }
 
   return ret_status;
@@ -2339,16 +2433,18 @@
         &has_outside_compilation));
     *modified |= has_outside_compilation;
 
-    string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
-    Node* pivot_node = node_name_index[pivot_name];
-    TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
-        g, fld, host_graph_func_name, n, pivot_node));
+    if (has_outside_compilation) {
+      string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
+      Node* pivot_node = node_name_index[pivot_name];
+      TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
+          g, fld, host_graph_func_name, n, pivot_node));
 
-    TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
+      TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
 
-    for (auto shape_inference_graph_name : shape_inference_graphs) {
-      TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name,
-                                                    g, pivot_node, fld));
+      for (auto shape_inference_graph_name : shape_inference_graphs) {
+        TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
+            shape_inference_graph_name, g, pivot_node, fld));
+      }
     }
   }
 
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
index 049ee82..907fe20 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -418,16 +418,8 @@
       host_compute_core, &fld, &shape_inference_graphs,
       &has_outside_compilation));
 
-  // Check host graph is empty.
-  std::unique_ptr<FunctionBody> host_fbody;
-  AttrValue device_ordinal_temp_value;
-  device_ordinal_temp_value.set_i(0);
-  protobuf::Map<string, AttrValue> host_func_attrs;
-  host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
-  TF_CHECK_OK(FunctionDefToBodyHelper(
-      *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody));
-  Graph *host_graph = host_fbody->graph;
-  EXPECT_EQ(host_graph->num_nodes(), 2);
+  // Check host graph is not created.
+  EXPECT_EQ(fld.Find("host_graph"), nullptr);
 }
 
 REGISTER_OP("XlaSendToHost")