[TF2XLA] Inject XLAControlFlowContext in ConcreteFunction, as it can be called directly, bypassing Function
PiperOrigin-RevId: 326322743
Change-Id: I7df50c5c650b3cd317e360f127ac1b782a14fbaf
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index ba75aed..f1e25c0 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -220,6 +220,8 @@
self.assertAllClose(40.0, f(2.0))
self.assertAllClose([40.0, 28.0], g(2.0))
+ self.assertAllClose(40.0, f.get_concrete_function(2.0)(2.0))
+ self.assertAllClose([40.0, 28.0], g.get_concrete_function(2.0)(2.0))
def testMethodCompilation(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f0004f0..bb4449a 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -57,6 +57,7 @@
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
@@ -1939,15 +1940,24 @@
possible_gradient_type,
executing_eagerly)
forward_function, args_with_tangents = forward_backward.forward()
- if executing_eagerly:
- flat_outputs = forward_function.call(
- ctx, args_with_tangents,
- cancellation_manager=cancellation_manager)
- else:
- with default_graph._override_gradient_function( # pylint: disable=protected-access
- {"PartitionedCall": self._get_gradient_function(),
- "StatefulPartitionedCall": self._get_gradient_function()}):
- flat_outputs = forward_function.call(ctx, args_with_tangents)
+ compiled_with_xla = self._attrs.get("_XlaMustCompile", False) and \
+ not control_flow_util.GraphOrParentsInXlaContext(default_graph)
+ xla_context = control_flow_ops.XLAControlFlowContext()
+ try:
+ if compiled_with_xla:
+ xla_context.Enter()
+ if executing_eagerly:
+ flat_outputs = forward_function.call(
+ ctx, args_with_tangents,
+ cancellation_manager=cancellation_manager)
+ else:
+ with default_graph._override_gradient_function( # pylint: disable=protected-access
+ {"PartitionedCall": self._get_gradient_function(),
+ "StatefulPartitionedCall": self._get_gradient_function()}):
+ flat_outputs = forward_function.call(ctx, args_with_tangents)
+ finally:
+ if compiled_with_xla:
+ xla_context.Exit()
forward_backward.record(flat_outputs)
return self._build_call_outputs(flat_outputs)