Also functionalize control flow in functions for UpgradeLegacyGraph

1) Previously, UpgradeLegacyGraph only functionalized control flow in the graph
   itself, not control flow in functions called from the graph. This caused
   problems because subsequent steps did not expect, and could not correctly
   handle, remaining v1 control flow. Now such control flow is functionalized,
   too (see the usage sketch below).
2) Refactored existing functionalization code.
3) Fixed a bug in the existing functionalization code: in certain cases the
   definition of a modified function (after its control flow was
   functionalized) could not be found when the calling graph node was
   rewritten (see the comment in AddFunctionDefToGraphLibrary for details).
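
Usage sketch (illustrative only; `graph` and `flib_def` stand for a
tensorflow::Graph* and a FunctionLibraryDefinition* the caller already owns,
matching the signature in functionalize_control_flow.h below):

  NodeFilter node_filter = {};  // empty filter: consider all nodes
  // Functionalize v1 control flow in `graph` and, with the new flag, also in
  // any function that is (transitively) called from it:
  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(graph, flib_def, node_filter,
                                              /*include_functions=*/true));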

PiperOrigin-RevId: 326755095
Change-Id: I7163f7f0359f5de978ec5d764949978a2341cbaa
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 94ddf76..51f6374 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -177,7 +177,8 @@
       restrict_functionalization_to_tpu_nodes
           ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
           : NodeFilter{};
-  return FunctionalizeControlFlow(graph, flib_def, node_filter);
+  return FunctionalizeControlFlow(graph, flib_def, node_filter,
+                                  /*include_functions=*/true);
 }
 
 // Stateful helper class to import a TensorFlow model into an MLIR Module.
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 10b26f9..596fa8e 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -46,12 +46,254 @@
 
 namespace tensorflow {
 
+// Helper functions for functionalizing control flow in functions.
+
+// Maps the canonicalized name of a function to
+// - the new function name, if the function body was functionalized
+// - absl::nullopt, if not
+using FuncMap = std::map<string, absl::optional<string>>;
+using FuncMapIter = std::map<string, absl::optional<string>>::const_iterator;
+
+// Returns whether the function has been processed before.
+bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) {
+  return func_iter != func_map->end();
+}
+
+// Returns whether the function was modified (i.e., functionalized) before.
+bool FunctionHasBeenModified(FuncMapIter func_iter) {
+  return func_iter->second.has_value();
+}
+
+// Returns a name for the new functionalized version of a function.
+string GetNewFunctionName(
+    const string& func_name, Node* n,
+    AssociatedFunctionInfo::AssociatedFunctionType func_type,
+    FunctionLibraryDefinition* fld) {
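+  // The "_f15n_" suffix used below is presumably a numeronym for
+  // "functionalization" ("f" + 15 letters + "n", in the style of "i18n");
+  // this reading is an assumption, not stated anywhere in the change.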
+  // For SymbolicGradient, `func_name` is always "SymbolicGradient", which
+  // is not very informative. Use the node name instead.
+  return (
+      func_type ==
+              AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient
+          ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"))
+          : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_")));
+}
+
+// Returns name to which a modified function has been mapped.
+const string& GetMappedFunctionName(FuncMapIter func_iter) {
+  DCHECK(func_iter->second.has_value());
+  return func_iter->second.value();
+}
+
+// Updates `func_map` with the function given by `canonicalized_name`.
+void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name,
+                       const string& new_func_name, bool function_modified) {
+  // If the function was modified, store its new name; otherwise add an empty
+  // entry to record that the function has been processed and does not need
+  // to be rewritten.
+  (*func_map)[canonicalized_name] =
+      function_modified ? absl::make_optional(new_func_name) : absl::nullopt;
+}
+
+// Adds the new function def to the graph's function library if necessary.
+Status AddFunctionDefToGraphLibrary(
+    const string& func_name, const AssociatedFunctionInfo& associated_function,
+    Graph* graph, FunctionLibraryDefinition* fld) {
+  const OpRegistrationData* op_reg_data;
+  // We have to be careful with adding the function def since there are three
+  // different `OpRegistryInterface`s involved here:
+  // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`.
+  // We have already added the function def to `fld` before calling this
+  // function but for the subsequent `RewriteAssociatedFunction` call we need
+  // the function def to be in one of the other two registries, otherwise
+  // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case
+  // because it cannot find the associated function def.
+  // On the other hand, we must not add the function def if it is already
+  // contained in one of the latter two registries; that would lead to errors
+  // when the function def is already in one registry and we try to add it to
+  // the other one (adding it to the same registry again is fine). This can
+  // happen when one of the latter two registries is identical to `fld`
+  // (which we already updated).
+  // Therefore, before adding the function def, we check whether it is
+  // already contained in either `graph->flib_def()` or
+  // `graph->flib_def().default_registry()`, which is done in the following
+  // line (we have to use `LookUp` instead of `Contains` or `Find` because
+  // neither of the latter checks the default registry).
+  if (graph->flib_def().LookUp(func_name, &op_reg_data).ok())
+    return Status::OK();
+
+  const FunctionDef* new_fdef = fld->Find(func_name);
+  DCHECK(new_fdef != nullptr);
+  FunctionDefLibrary fdef_lib;
+  *(fdef_lib.add_function()) = *new_fdef;
+  return graph->AddFunctionLibrary(fdef_lib);
+}
+
+// Functionalizes the function given by `func_name`. Updates `func_map`
+// accordingly.
+Status FunctionalizeControlFlowForFunction(
+    const string& func_name, const string& new_func_name,
+    const protobuf::Map<string, tensorflow::AttrValue>& attrs,
+    FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
+    FuncMap* func_map, bool* function_modified,
+    const NodeFilter& node_filter = {});
+
+// Functionalizes all functions that are (directly or indirectly) associated
+// with any node in `graph`. Adds processed functions to `func_map`.
+Status FunctionalizeControlFlowForNodeAssociatedFunctions(
+    FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld,
+    FunctionLibraryRuntime* flr, bool* any_function_modified,
+    const NodeFilter& node_filter) {
+  std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
+      nodes_to_associated_functions;
+  for (auto* n : graph->nodes()) {
+    auto associated_functions = GetAssociatedFunctions(*n, fld);
+    if (!associated_functions.empty()) {
+      nodes_to_associated_functions.push_back({n, associated_functions});
+    }
+  }
+  for (const auto& pair : nodes_to_associated_functions) {
+    Node* n = pair.first;
+    auto associated_functions = pair.second;
+    for (auto& associated_function : associated_functions) {
+      // Note that if `n` is a function call node, then potential calls to
+      // `RewriteAssociatedFunction` below might delete `n` and create a new
+      // node instead, making `n` an invalid pointer. That's fine because in
+      // that case `n` only has one associated function, so this loop has only
+      // one iteration and we don't use `n` again after the rewrite.
+      // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed
+      // below.
+      DCHECK(associated_function.type() !=
+                 AssociatedFunctionInfo::kFunctionCallNode ||
+             associated_functions.size() == 1);
+
+      // Process one node-function pair.
+      string func_name = associated_function.func_name();
+      string canonicalized_name =
+          Canonicalize(func_name, AttrSlice(&associated_function.attrs()));
+      auto func_iter = func_map->find(canonicalized_name);
+      string new_func_name;
+      if (FunctionHasBeenProcessed(func_iter, func_map)) {
+        if (FunctionHasBeenModified(func_iter)) {
+          *any_function_modified = true;
+          new_func_name = GetMappedFunctionName(func_iter);
+          TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+              graph, n, fld, associated_function, new_func_name));
+        }
+        continue;
+      }
+      // Function is processed for the first time.
+      bool function_modified = false;
+      new_func_name =
+          GetNewFunctionName(func_name, n, associated_function.type(), fld);
+      // Perform functionalization for current function.
+      TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+          func_name, new_func_name, associated_function.attrs(), fld, flr,
+          func_map, &function_modified, node_filter));
+      UpdateFunctionMap(func_map, canonicalized_name, new_func_name,
+                        function_modified);
+      if (function_modified) {
+        *any_function_modified = true;
+        TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary(
+            new_func_name, associated_function, graph, fld));
+        TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+            graph, n, fld, associated_function, new_func_name));
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status FunctionalizeControlFlowForFunction(
+    const string& func_name, const string& new_func_name,
+    const protobuf::Map<string, tensorflow::AttrValue>& attrs,
+    FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
+    FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) {
+  *function_modified = false;
+
+  // Convert the function to a graph.
+  FunctionLibraryRuntime::Handle handle;
+  TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
+  Status ret_status = Status::OK();
+  auto cleanup_handle = gtl::MakeCleanup([&]() {
+    auto s = flr->ReleaseHandle(handle);
+    if (!s.ok()) {
+      ret_status.Update(s);
+    }
+  });
+  const FunctionBody* body = flr->GetFunctionBody(handle);
+  Graph* g = body->graph;
+
+  // Check whether the graph has a Switch or Merge node.
+  bool has_switch_or_merge = false;
+  for (Node* n : body->graph->nodes()) {
+    // Skip nodes that are filtered out.
+    if (node_filter && !node_filter(n)) continue;
+    if (n->type_string() == "Switch" || n->type_string() == "Merge") {
+      has_switch_or_merge = true;
+      break;
+    }
+  }
+  // Before functionalizing control flow in `g` we functionalize control flow
+  // in functions (directly or indirectly) associated with nodes in `g`.
+  // Note that we do this even if `g` itself has no Switch or Merge nodes: it
+  // might still contain calls to functions whose bodies have v1 control
+  // flow, and those functions (and the calling nodes) still need rewriting.
+  TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
+      func_map, g, fld, flr, function_modified, node_filter));
+
+  if (has_switch_or_merge) {
+    *function_modified = true;
+
+    // Functionalize the function body.
+    if (VLOG_IS_ON(4)) {
+      DumpGraphToFile(
+          absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+          *g, fld);
+    }
+    TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter));
+    if (VLOG_IS_ON(4)) {
+      DumpGraphToFile(
+          absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
+          fld);
+    }
+  }
+  if (*function_modified) {
+    // Add rewritten FunctionDef into library.
+    FunctionDef functionalized_fdef;
+    TF_RETURN_IF_ERROR(
+        GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
+    if (func_name == new_func_name) {
+      VLOG(2) << "Replacing function " << func_name;
+      TF_RETURN_IF_ERROR(
+          fld->ReplaceFunction(new_func_name, functionalized_fdef));
+    } else {
+      VLOG(2) << "Adding function " << new_func_name;
+      TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+    }
+  }
+
+  return ret_status;
+}
+
 Status FunctionalizeControlFlow(Graph* graph,
                                 FunctionLibraryDefinition* library,
-                                const NodeFilter& node_filter) {
+                                const NodeFilter& node_filter,
+                                bool include_functions) {
   VLOG(2) << "FunctionalizeControlFlow (initial): "
           << DumpGraphToFile("functionalize_initial", *graph, library);
 
+  if (include_functions) {
+    // Functionalize control flow in functions that are (directly or
+    // indirectly) associated with a node in `graph`.
+    auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
+        /*device_mgr=*/nullptr, tensorflow::Env::Default(),
+        /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library,
+        tensorflow::OptimizerOptions());
+    // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice`
+    // (because we constructed it with `device_mgr = nullptr`).
+    FunctionLibraryRuntime* flr =
+        pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+
+    FuncMap func_map;
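+    // The `modified` flag is only needed to satisfy the callee's interface;
+    // it is not consulted here, because the graph-level functionalization
+    // below runs unconditionally.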
+    bool modified = false;
+    TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
+        &func_map, graph, library, flr, &modified, node_filter));
+  }
   // Functionalize and remove while loops from graph.
   TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter));
 
@@ -68,153 +310,19 @@
 
 Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
                                            FunctionLibraryDefinition* library,
-                                           const NodeFilter& node_filter) {
+                                           const NodeFilter& node_filter,
+                                           bool include_functions) {
   FunctionDefLibrary function_lib = graph_def->library();
   Graph graph(OpRegistry::Global());
 
   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
-  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter));
+  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter,
+                                              include_functions));
   graph.ToGraphDef(graph_def);
   std::swap(*graph_def->mutable_library(), function_lib);
   return Status::OK();
 }
 
-Status FunctionalizeControlFlowForFunction(
-    const string& func_name, const string& new_func_name,
-    const protobuf::Map<string, tensorflow::AttrValue>& attrs,
-    FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
-    std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
-    bool* modified) {
-  *modified = false;
-
-  // Convert the function to Graph.
-  FunctionLibraryRuntime::Handle handle;
-  TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
-  Status ret_status = Status::OK();
-  auto cleanup_handle = gtl::MakeCleanup([&]() {
-    auto s = flr->ReleaseHandle(handle);
-    if (!s.ok()) {
-      ret_status.Update(s);
-    }
-  });
-  const FunctionBody* body = flr->GetFunctionBody(handle);
-  Graph* g = body->graph;
-
-  // Check if the graph has Switch or Merge node.
-  bool has_switch_or_merge = false;
-  for (Node* n : body->graph->nodes()) {
-    if (n->type_string() == "Switch" || n->type_string() == "Merge") {
-      has_switch_or_merge = true;
-      break;
-    }
-  }
-  // We cannot return here directly if the graph has no Switch/Merge.
-  // It might contain function call nodes, or If/While nodes with Switch/Merge
-  // in function body. We still need to rewrite those functions and modify
-  // corresponding nodes.
-
-  // If any node has associated functions, functionalize them first.
-  // Gather nodes with associated functions first, because rewriting those nodes
-  // might involve node deletion/addition. Avoid modifying nodes while iterating
-  // it.
-  std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
-      nodes_to_associated_functions;
-  for (auto* n : g->nodes()) {
-    auto associated_functions = GetAssociatedFunctions(*n, fld);
-    if (!associated_functions.empty()) {
-      nodes_to_associated_functions.push_back({n, associated_functions});
-    }
-  }
-  for (const auto& iter : nodes_to_associated_functions) {
-    Node* n = iter.first;
-    auto associated_functions = iter.second;
-    for (auto& associated_function : associated_functions) {
-      string name = associated_function.func_name();
-      string canonicalized_name =
-          Canonicalize(name, AttrSlice(&associated_function.attrs()));
-      auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
-      string new_name;
-      bool function_modified;
-      if (iter != canonicalized_name_to_new_name->end()) {
-        // If we already processed this function, check if it was rewritten. If
-        // the function was rewritten, the entry will be non-empty. Otherwise
-        // the entry will be empty.
-        function_modified = iter->second.has_value();
-        if (function_modified) {
-          new_name = iter->second.value();
-        }
-      } else {
-        if (associated_function.type() ==
-            AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
-          // For SymbolicGradient, `name` is always "SymbolicGradient",
-          // which is not very informative. Use node name instead.
-          new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"));
-        } else {
-          new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
-        }
-        TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
-            name, new_name, associated_function.attrs(), fld, flr,
-            canonicalized_name_to_new_name, &function_modified));
-        if (function_modified) {
-          // If the function was rewritten, add an non-empty entry. So later we
-          // know we have processed this function, and it was rewritten into
-          // another function.
-          (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
-        } else {
-          // If the function was not rewritten, add an empty entry. So later
-          // we know we have processed this function, and it does not need to be
-          // rewritten.
-          (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
-        }
-      }
-      if (function_modified) {
-        *modified = true;
-
-        // Notice that if "n" is a function call, RewriteAssociatedFunction()
-        // will delete it and create a new node instead, making "n" an invalid
-        // pointer. That's fine because in that case, associated_functions will
-        // only have one member and the loop will only run once.
-        TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
-            g, n, fld, associated_function, new_name));
-      }
-    }
-  }
-
-  if (has_switch_or_merge) {
-    *modified = true;
-
-    // Functionalize the function body.
-    if (VLOG_IS_ON(4)) {
-      DumpGraphToFile(
-          absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
-          *g, fld);
-    }
-    TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
-    if (VLOG_IS_ON(4)) {
-      DumpGraphToFile(
-          absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
-          fld);
-    }
-  }
-
-  if (*modified) {
-    // Add rewritten FunctionDef into library.
-    FunctionDef functionalized_fdef;
-    TF_RETURN_IF_ERROR(
-        GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
-    if (func_name == new_func_name) {
-      VLOG(2) << "Replacing function " << func_name;
-      TF_RETURN_IF_ERROR(
-          fld->ReplaceFunction(new_func_name, functionalized_fdef));
-    } else {
-      VLOG(2) << "Adding function " << new_func_name;
-      TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
-    }
-  }
-
-  return ret_status;
-}
-
 Status FunctionalizeControlFlowForXlaPass::Run(
     const GraphOptimizationPassOptions& options) {
   Graph* graph = options.graph->get();
@@ -241,7 +349,7 @@
           // XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
           {"XlaLaunch", "function"},
       };
-  std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
+  FuncMap func_map;
   bool fld_modified = false;
   for (Node* n : graph->nodes()) {
     auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
@@ -258,7 +366,7 @@
     bool modified;
     TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
         func.name(), new_func_name, func.attr(), options.flib_def, flr,
-        &canonicalized_name_to_new_name, &modified));
+        &func_map, &modified));
     if (modified) {
       n->ClearAttr(func_attr);
       func.set_name(new_func_name);
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index f9e751e..46abae2 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -30,6 +30,13 @@
 //
 // If `node_filter` is defined, then only loops and conditions for whose
 // nodes `node_filter` returns true are functionalized.
+
+// If `include_functions` is true, then loops and conditions inside functions
+// that are associated with nodes in `graph` (e.g., a function called from a
+// node in `graph`) are also functionalized; otherwise they are not.
+// This also handles transitive cases: e.g., a function body will be
+// functionalized when it is called from another function that is called by
+// some node in `graph` (and so on). The node filter also applies here.
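+//
+// Example (sketch, with hypothetical function names): if `graph` contains a
+// call to function F, F calls G, and G's body contains v1 Switch/Merge
+// nodes, then with `include_functions == true` G is functionalized, F is
+// rewritten to call the functionalized version of G, and the call site in
+// `graph` is rewritten to call the new version of F.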
 //
 // Precondition:
 // For any node in a loop or condition for which `node_filter` returns true,
@@ -43,11 +50,13 @@
 // satisfies the above conditions.
 Status FunctionalizeControlFlow(Graph* graph,
                                 FunctionLibraryDefinition* library,
-                                const NodeFilter& node_filter = {});
+                                const NodeFilter& node_filter = {},
+                                bool include_functions = false);
 
 Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
                                            FunctionLibraryDefinition* library,
-                                           const NodeFilter& node_filter = {});
+                                           const NodeFilter& node_filter = {},
+                                           bool include_functions = false);
 
 // This pass looks at the graph, and turns V1 control flow structure
 // (Switch/Merge/etc.) into V2 control flow structure (If/While).
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index 79a042a..951ebdd 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -27,12 +27,15 @@
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/validate.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/dump_graph.h"
 #include "tensorflow/core/util/equal_graph_def.h"
 
 namespace tensorflow {
@@ -63,18 +66,41 @@
 //     math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
 //     lambda: math_ops.add(x, 23))
 //
-// Tests different node filters.
-class ConditionalTestFixture : public ::testing::TestWithParam<bool> {
+// Tests different node filters and functionalization inside a function.
+class ConditionalTestFixture
+    : public ::testing::TestWithParam<std::tuple<bool, bool>> {
  protected:
-  void SetUp() override { restrict_to_tpu_nodes_ = GetParam(); }
+  void SetUp() override {
+    restrict_to_tpu_nodes_ = std::get<0>(GetParam());
+    wrap_condition_in_function_ = std::get<1>(GetParam());
+  }
   void RunTest();
 
  private:
+  void BuildCondGraph(Graph* cond_graph);
+  void CheckGraphDef(const GraphDef& graph_def,
+                     const FunctionLibraryDefinition& library);
+
   bool restrict_to_tpu_nodes_ = false;
+  bool wrap_condition_in_function_ = false;
 };
 
-void ConditionalTestFixture::RunTest() {
-  Graph graph(OpRegistry::Global());
+TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); }
+
+INSTANTIATE_TEST_SUITE_P(
+    FunctionalizeControlFlow, ConditionalTestFixture,
+    ::testing::Combine(::testing::Bool(), ::testing::Bool()),
+    [](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>&
+           info) {
+      bool restrict_to_tpu_nodes = std::get<0>(info.param);
+      bool wrap_cond_in_function = std::get<1>(info.param);
+      string name =
+          absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter",
+                       wrap_cond_in_function ? "_in_function" : "_in_graph");
+      return name;
+    });
+
+void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) {
   {
     Scope scope = Scope::NewRootScope().ExitOnError();
 
@@ -102,13 +128,117 @@
     auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
                             std::initializer_list<Input>{add, mul});
 
-    TF_EXPECT_OK(scope.ToGraph(&graph));
+    TF_EXPECT_OK(scope.ToGraph(cond_graph));
 
     // Set `_tpu_replicate` attribute for all nodes.
-    for (Node* n : graph.nodes()) {
+    for (Node* n : cond_graph->nodes()) {
       n->AddAttr("_tpu_replicate", "cluster");
     }
   }
+}
+
+void ConditionalTestFixture::CheckGraphDef(
+    const GraphDef& graph_def, const FunctionLibraryDefinition& library) {
+  string op_name;
+  NameAttrList then_fn;
+  NameAttrList else_fn;
+  TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
+  InstantiationResultForTest else_result;
+  TF_EXPECT_OK(
+      InstantiateFunctionForTest(else_fn.name(), library, &else_result));
+
+  // Outer graph
+  {
+    Scope scope = Scope::NewRootScope().ExitOnError();
+    auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
+    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
+    auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
+    auto if_op =
+        ops::If(scope.WithOpName(op_name), less,
+                std::initializer_list<Input>{less, y, x}, {DT_INT32}, then_fn,
+                else_fn, ops::If::OutputShapes({PartialTensorShape()}));
+    auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
+    GraphDef expected;
+    TF_EXPECT_OK(scope.ToGraphDef(&expected));
+    TF_EXPECT_GRAPH_EQ(expected, graph_def);
+  }
+
+  // then body.
+  {
+    Scope scope = Scope::NewRootScope().ExitOnError();
+    auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
+    auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
+    auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
+    auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
+    auto cond = ops::Const(
+        scope.WithOpName("cond").WithControlDependencies(identity), 17);
+    auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
+    auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0);
+
+    GraphDef expected;
+    TF_EXPECT_OK(scope.ToGraphDef(&expected));
+
+    InstantiationResultForTest result;
+    TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
+
+    EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
+    EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
+    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+  }
+
+  // else body.
+  {
+    Scope scope = Scope::NewRootScope().ExitOnError();
+    auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
+    auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
+    auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
+    auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
+    auto cond_1 = ops::Const(
+        scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
+    auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
+    auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
+
+    GraphDef expected;
+    TF_EXPECT_OK(scope.ToGraphDef(&expected));
+
+    InstantiationResultForTest result;
+    TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
+
+    EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
+    EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
+    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
+  }
+}
+
+void ConditionalTestFixture::RunTest() {
+  Graph graph(OpRegistry::Global());
+  if (wrap_condition_in_function_) {
+    // Wrap the condition in a function that is called from `graph`.
+    Scope scope = Scope::NewRootScope().ExitOnError();
+    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
+
+    Graph cond_graph(OpRegistry::Global());
+    BuildCondGraph(&cond_graph);
+
+    FunctionDef cond_fdef;
+    TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef));
+
+    FunctionDefLibrary fdef_lib;
+    *(fdef_lib.add_function()) = cond_fdef;
+    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
+    NodeDef cond_fn;
+    cond_fn.set_name("cond_node");
+    cond_fn.set_op("cond_fn");
+    *(cond_fn.add_input()) = "source";
+    Status status;
+    scope.graph()->AddNode(cond_fn, &status);
+    TF_ASSERT_OK(status);
+    TF_ASSERT_OK(scope.ToGraph(&graph));
+  } else {
+    // Build condition in `graph`.
+    BuildCondGraph(&graph);
+  }
+  FunctionLibraryDefinition library(graph.flib_def());
   // If `restrict_to_tpu_nodes_` is true let filter function return true for
   // `_tpu_replicate` nodes.
   NodeFilter node_filter =
@@ -116,99 +246,47 @@
           ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
           : NodeFilter{};
 
-  FunctionLibraryDefinition library(OpRegistry::Global(), {});
   GraphDef optimized_graph_def;
   graph.ToGraphDef(&optimized_graph_def);
-  TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(&optimized_graph_def,
-                                                   &library, node_filter));
-  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library, node_filter));
-  GraphDef converted_graph_def;
-  graph.ToGraphDef(&converted_graph_def);
+  TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
+      &optimized_graph_def, &library, node_filter,
+      /*include_functions=*/wrap_condition_in_function_));
+  TF_ASSERT_OK(FunctionalizeControlFlow(
+      &graph, &library, node_filter,
+      /*include_functions=*/wrap_condition_in_function_));
 
-  for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
-    string op_name;
-    NameAttrList then_fn;
-    NameAttrList else_fn;
-    TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
-    InstantiationResultForTest else_result;
-    TF_EXPECT_OK(
-        InstantiateFunctionForTest(else_fn.name(), library, &else_result));
+  if (wrap_condition_in_function_) {
+    // Check whether the function body was functionalized.
+    auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
+        /*device_mgr=*/nullptr, tensorflow::Env::Default(),
+        /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library,
+        tensorflow::OptimizerOptions());
+    FunctionLibraryRuntime* flr =
+        pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+    FunctionLibraryRuntime::Handle handle;
 
-    // Outer graph
-    {
-      Scope scope = Scope::NewRootScope().ExitOnError();
-      auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
-      auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
-      auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
-      auto if_op =
-          ops::If(scope.WithOpName(op_name), less,
-                  std::initializer_list<Input>{less, y, x}, {DT_INT32}, then_fn,
-                  else_fn, ops::If::OutputShapes({PartialTensorShape()}));
-      auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
-      GraphDef expected;
-      TF_EXPECT_OK(scope.ToGraphDef(&expected));
-      TF_EXPECT_GRAPH_EQ(expected, graph_def);
+    // The functionalized function's name is the type string of `cond_node`.
+    string func_name;
+    for (Node* n : graph.nodes()) {
+      if (n->name() == "cond_node") {
+        func_name = n->type_string();
+        break;
+      }
     }
-
-    // then body.
-    {
-      Scope scope = Scope::NewRootScope().ExitOnError();
-      auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
-      auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
-      auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
-      auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
-      auto cond = ops::Const(
-          scope.WithOpName("cond").WithControlDependencies(identity), 17);
-      auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
-      auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0);
-
-      GraphDef expected;
-      TF_EXPECT_OK(scope.ToGraphDef(&expected));
-
-      InstantiationResultForTest result;
-      TF_EXPECT_OK(
-          InstantiateFunctionForTest(then_fn.name(), library, &result));
-
-      EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
-      EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
-                result.arg_types);
-      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
-    }
-
-    // else body.
-    {
-      Scope scope = Scope::NewRootScope().ExitOnError();
-      auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
-      auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
-      auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
-      auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
-      auto cond_1 = ops::Const(
-          scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
-      auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
-      auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
-
-      GraphDef expected;
-      TF_EXPECT_OK(scope.ToGraphDef(&expected));
-
-      InstantiationResultForTest result;
-      TF_EXPECT_OK(
-          InstantiateFunctionForTest(else_fn.name(), library, &result));
-
-      EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
-      EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
-                result.arg_types);
-      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
-    }
+    TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle));
+    const FunctionBody* body = flr->GetFunctionBody(handle);
+    GraphDef graph_def;
+    body->graph->ToGraphDef(&graph_def);
+    CheckGraphDef(graph_def, library);
+  } else {
+    // Check whether the graphs were functionalized.
+    CheckGraphDef(optimized_graph_def, library);
+    GraphDef converted_graph_def;
+    graph.ToGraphDef(&converted_graph_def);
+    CheckGraphDef(converted_graph_def, library);
   }
 }
 
-TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); }
-
-INSTANTIATE_TEST_SUITE_P(
-    FunctionalizeControlFlow, ConditionalTestFixture, ::testing::Bool(),
-    [](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>&
-           info) { return info.param ? "with_filter" : "without_filter"; });
-
 // Returns the names of the "cond" and "body" functions for the While node
 // in a graph.
 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,