[TF2XLA] Inject XLAControlFlowContext in ConcreteFunction, as it can be called directly, bypassing Function

PiperOrigin-RevId: 326322743
Change-Id: I7df50c5c650b3cd317e360f127ac1b782a14fbaf
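
Background, as a minimal sketch (not part of this change): the XLA context is
entered on the def_function.Function call path, but a ConcreteFunction obtained
via get_concrete_function can be invoked directly and skips that path, so the
context now has to be entered inside ConcreteFunction as well. The snippet
below only illustrates the two call paths; experimental_compile is the flag
name at the time of this commit (later TensorFlow releases rename it to
jit_compile).

    import tensorflow as tf

    @tf.function(experimental_compile=True)
    def f(x):
      return x * x + 1.0

    # Going through Function: the XLA control-flow context is injected
    # on this path already.
    f(tf.constant(2.0))

    # Calling the ConcreteFunction directly bypasses Function, so the
    # context must also be entered inside ConcreteFunction (this change).
    cf = f.get_concrete_function(tf.constant(2.0))
    cf(tf.constant(2.0))
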
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index ba75aed..f1e25c0 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -220,6 +220,8 @@
 
       self.assertAllClose(40.0, f(2.0))
       self.assertAllClose([40.0, 28.0], g(2.0))
+      self.assertAllClose(40.0, f.get_concrete_function(2.0)(2.0))
+      self.assertAllClose([40.0, 28.0], g.get_concrete_function(2.0)(2.0))
 
   def testMethodCompilation(self):
 
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f0004f0..bb4449a 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -57,6 +57,7 @@
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import functional_ops
@@ -1939,15 +1940,24 @@
         possible_gradient_type,
         executing_eagerly)
     forward_function, args_with_tangents = forward_backward.forward()
-    if executing_eagerly:
-      flat_outputs = forward_function.call(
-          ctx, args_with_tangents,
-          cancellation_manager=cancellation_manager)
-    else:
-      with default_graph._override_gradient_function(  # pylint: disable=protected-access
-          {"PartitionedCall": self._get_gradient_function(),
-           "StatefulPartitionedCall": self._get_gradient_function()}):
-        flat_outputs = forward_function.call(ctx, args_with_tangents)
+    compiled_with_xla = (self._attrs.get("_XlaMustCompile", False) and
+                         not control_flow_util.GraphOrParentsInXlaContext(default_graph))
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    try:
+      if compiled_with_xla:
+        xla_context.Enter()
+      if executing_eagerly:
+        flat_outputs = forward_function.call(
+            ctx, args_with_tangents,
+            cancellation_manager=cancellation_manager)
+      else:
+        with default_graph._override_gradient_function(  # pylint: disable=protected-access
+            {"PartitionedCall": self._get_gradient_function(),
+             "StatefulPartitionedCall": self._get_gradient_function()}):
+          flat_outputs = forward_function.call(ctx, args_with_tangents)
+    finally:
+      if compiled_with_xla:
+        xla_context.Exit()
     forward_backward.record(flat_outputs)
     return self._build_call_outputs(flat_outputs)