Make Grappler also ignore functions transitively called by XlaLaunch ops
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 2f1c869..02ed3e2 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -661,18 +661,41 @@
find_differentiable_functions(function.node_def());
}
- // Find functions that are formed by XLA and will be compiled later. We do it
- // by looking for a function attribute in XlaLaunch ops. Grappler rewrites
- // potentially can add nodes that are not supported by XLA, so we choose to
- // skip such functions when we optimize function library.
+ // Find functions that will be compiled by XLA later
+ // We do it by looking for XlaLaunch ops that call functions,
+ // then depth first search down those functions to find transitive functions.
+ // Grappler rewrites can potentially add nodes that are
+ // not supported by XLA, so we choose to skip such functions when we optimize
+ // the function library.
absl::flat_hash_set<string> xla_compiled_functions;
+ std::function<void(const string&)> find_all_functions;
+ find_all_functions = [&](const string& func) -> void {
+ // Ignore call cycles in the graph
+ if (xla_compiled_functions.contains(func)) return;
+ // Find func in the flib
+ const FunctionDef* func_def = flib.Find(func);
+ CHECK(func_def) << "not found: " << func;
+ // Mark function to be ignored by grappler
+ xla_compiled_functions.insert(func);
+ // Depth first search through the func for transitively called funcs
+ for (const NodeDef& node : func_def->node_def()) {
+ for (const auto attr : node.attr()) {
+ const AttrValue& attr_value = attr.second;
+ if (attr_value.has_func()) {
+ find_all_functions(attr_value.func().name());
+ }
+ }
+ }
+ };
- const auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
+ auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
NameAttrList function;
for (const NodeDef& node : nodes) {
+ // Look only for XlaLaunch nodes that call a function
if (!IsXlaLaunch(node)) continue;
if (!GetNodeAttr(node, "function", &function).ok()) continue;
- xla_compiled_functions.insert(function.name());
+ // Find all transitively called functions
+ find_all_functions(function.name());
}
};