Make Grappler also ignore functions transitively called by XlaLaunch ops
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 2f1c869..02ed3e2 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -661,18 +661,41 @@
     find_differentiable_functions(function.node_def());
   }
 
-  // Find functions that are formed by XLA and will be compiled later. We do it
-  // by looking for a function attribute in XlaLaunch ops. Grappler rewrites
-  // potentially can add nodes that are not supported by XLA, so we choose to
-  // skip such functions when we optimize function library.
+  // Find functions that will be compiled by XLA later
+  // We do it by looking for XlaLaunch ops that call functions,
+  // then depth first search down those functions to find transitive functions.
+  // Grappler rewrites can potentially add nodes that are
+  // not supported by XLA, so we choose to skip such functions when we optimize
+  // the function library.
   absl::flat_hash_set<string> xla_compiled_functions;
+  std::function<void(const string&)> find_all_functions;
+  find_all_functions = [&](const string& func) -> void {
+    // Ignore call cycles in the graph
+    if (xla_compiled_functions.contains(func)) return;
+    // Find func in the flib
+    const FunctionDef* func_def = flib.Find(func);
+    CHECK(func_def) << "not found: " << func;
+    // Mark function to be ignored by grappler
+    xla_compiled_functions.insert(func);
+    // Depth first search through the func for transitively called funcs
+    for (const NodeDef& node : func_def->node_def()) {
+      for (const auto attr : node.attr()) {
+        const AttrValue& attr_value = attr.second;
+        if (attr_value.has_func()) {
+          find_all_functions(attr_value.func().name());
+        }
+      }
+    }
+  };
 
-  const auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
+  auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
     NameAttrList function;
     for (const NodeDef& node : nodes) {
+      // Look only for XlaLaunch nodes that call a function
       if (!IsXlaLaunch(node)) continue;
       if (!GetNodeAttr(node, "function", &function).ok()) continue;
-      xla_compiled_functions.insert(function.name());
+      // Find all transitively called functions
+      find_all_functions(function.name());
     }
   };