Revert the change to still override the _XlaScope when auto_jit is on.
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 486f60e..91423f6 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -925,11 +925,16 @@
absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
// Look for an _XlaScope on both nodes. If both nodes have a scope and the
- // scopes do not match, do not cluster along this edge. If even one of the
- // nodes lacks an _XlaScope attribute, then it is treated as a "bridge" and
- // a cluster may be created along it. We may want to restrict this behavior
- // to require all nodes marked with _XlaCompile=true to also have a _XlaScope
- // property set (and raise an error otherwise); but for now we don't do this.
+ // scopes do not match, do not cluster along this edge. This restriction is
+ // overridden if the global_jit_level_ is ON. If even one of the nodes lacks
+ // an _XlaScope attribute, then it is treated as a "bridge" and a cluster may
+ // be created along it. We may want to restrict this behavior to require all
+ // nodes marked with _XlaCompile=true to also have a _XlaScope property set
+ // (and raise an error otherwise); but for now we don't do this.
+ if (global_jit_level_ != OptimizerOptions::OFF) {
+ return absl::nullopt;
+ }
+
string scope;
if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
return scope;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 9f6bdcc..cbe60b0 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -439,7 +439,7 @@
EXPECT_EQ(0, clusters.size());
}
-TEST(XlaCompilationTest, CyclesWithAllDifferentScopesRespectedByGlobalJit) {
+TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
@@ -465,8 +465,11 @@
// The computation is: C = A + relu(A)
// where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
- // In this case, the GlobalJitLevel respects the scopes to cluster.
- EXPECT_EQ(0, clusters.size());
+ // In this case, the GlobalJitLevel overrides the scopes to cluster while
+ // ignoring scopes.
+ EXPECT_EQ(3, clusters.size());
+ EXPECT_EQ(clusters["A"], clusters["B"]);
+ EXPECT_EQ(clusters["A"], clusters["C"]);
}
TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {