Make TPUStrategy work with tf.function(experimental_compile=True). This involves two changes:

1. Only create the replicated variable handle inside a TPUReplicateContext.
2. If a function annotated with experimental_compile=True is called inside an XLAControlFlowContext, don't create a new XLAControlFlowContext (see the usage sketch below).
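
Together, these two changes allow a compiled function to run under TPUStrategy. A minimal usage sketch, assuming a reachable TPU and the TF 2.1-era public API names corresponding to the internal symbols used elsewhere in this change (compiled_add and train_step are illustrative names, not part of this change):

    import tensorflow as tf

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    with strategy.scope():
      # Replicated variable; per change 1, its replicated handle is only
      # created inside a TPUReplicateContext.
      v = tf.Variable(1.0)

    @tf.function(experimental_compile=True)
    def compiled_add(x):
      # Runs inside the TPUReplicateContext entered by experimental_run_v2;
      # per change 2, no nested XLAControlFlowContext is created here.
      return x + v

    @tf.function
    def train_step(x):
      return strategy.experimental_run_v2(compiled_add, args=(x,))

    print(train_step(tf.constant(2.0)))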

PiperOrigin-RevId: 296086034
Change-Id: I821f3b3cd5ba69cd4c7bdb9c28e13e4b4c83f967
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index bc6865c..a4e2795 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -620,6 +620,7 @@
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/training/tracking:base",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/distribute/custom_training_loop_models_test.py b/tensorflow/python/distribute/custom_training_loop_models_test.py
index dcce40a..6fafa43 100644
--- a/tensorflow/python/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_models_test.py
@@ -356,6 +356,50 @@
 
     @def_function.function
     def train_step(iterator):
+
+      def step_fn(inputs):
+        images, targets = inputs
+        with backprop.GradientTape() as tape:
+          outputs = model(images)
+          loss = math_ops.reduce_sum(outputs - targets)
+        grads = tape.gradient(loss, model.variables)
+        return grads
+
+      outputs = distribution.experimental_run_v2(
+          step_fn, args=(next(iterator),))
+      return nest.map_structure(distribution.experimental_local_results,
+                                outputs)
+
+    train_step(input_iterator)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
+  def test_tf_function_experimental_compile(self, distribution):
+    dataset = self._get_dataset()
+    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+
+    class CustomDense(keras.layers.Layer):
+
+      def __init__(self, num_outputs):
+        super(CustomDense, self).__init__()
+        self.num_outputs = num_outputs
+
+      def build(self, input_shape):
+        self.kernel = self.add_variable(
+            "kernel", shape=[int(input_shape[-1]), self.num_outputs])
+
+      @def_function.function(experimental_compile=True)
+      def call(self, inputs):
+        return math_ops.matmul(inputs, self.kernel)
+
+    with distribution.scope():
+      x = keras.layers.Input(shape=(3,))
+      y = CustomDense(4)(x)
+      model = keras.Model(x, y)
+
+    @def_function.function
+    def train_step(iterator):
       def step_fn(inputs):
         images, targets = inputs
         with backprop.GradientTape() as tape:
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index baf3b82..74e9c60 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -38,6 +38,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.tpu import tpu
 from tensorflow.python.training import saver
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
@@ -938,14 +939,14 @@
 
 
 def _enclosing_tpu_context():
-  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
   graph = ops.get_default_graph()
   while graph is not None:
     # pylint: disable=protected-access
     context_ = graph._get_control_flow_context()
     # pylint: enable=protected-access
     while context_ is not None:
-      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
+      if isinstance(context_, tpu.TPUReplicateContext):
         return context_
       context_ = context_.outer_context
     # This may be a FuncGraph due to defuns or v2 control flow. We need to
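The isinstance check above is narrowed because TPUReplicateContext subclasses XLAControlFlowContext, and with this change tf.function(experimental_compile=True) also enters an XLAControlFlowContext; the broader check would therefore misidentify a compiled (but non-TPU-replicated) function as being inside a TPU context. A sketch of the relationship, assuming the class hierarchy as of this change:

    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.tpu import tpu

    # TPUReplicateContext extends XLAControlFlowContext, so the old check
    # matched both the TPU replication context and the plain compile
    # context; the new check matches only the former.
    assert issubclass(tpu.TPUReplicateContext,
                      control_flow_ops.XLAControlFlowContext)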
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 65d0784..7aef5da 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -689,6 +689,7 @@
         ":lift_to_graph",
         "//tensorflow/python:cond_v2",  # TODO(b/118513001): Imported via control_flow_ops; remove.
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:util",
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index a2bcb91..76af2d3 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -31,6 +31,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -563,9 +564,12 @@
       return self._python_function(*args, **kwds)
 
     tracing_count = self._get_tracing_count()
-    if self._experimental_compile:
+    if self._experimental_compile and (
+        not control_flow_util.GraphOrParentsInXlaContext(
+            ops.get_default_graph())):
       # V2 control flow relies on XLAControlFlowContext to generate a
-      # XLA-compatible function graph.
+      # XLA-compatible function graph. If the function is already called
+      # inside an XLA context, we don't create a nested one.
       xla_context = control_flow_ops.XLAControlFlowContext()
       try:
         xla_context.Enter()
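In isolation, the guard added above amounts to the following pattern. This is a minimal sketch with a hypothetical helper name (maybe_enter_xla_context is not part of this change), mirroring how __call__ decides whether to enter a new XLAControlFlowContext:

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import control_flow_util

    def maybe_enter_xla_context(fn, *args, **kwargs):
      """Hypothetical helper mirroring the guard in def_function.py."""
      if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
        # Already inside an XLA context (e.g. a TPUReplicateContext set up
        # by TPUStrategy), so don't create a nested XLAControlFlowContext.
        return fn(*args, **kwargs)
      xla_context = control_flow_ops.XLAControlFlowContext()
      try:
        xla_context.Enter()
        return fn(*args, **kwargs)
      finally:
        xla_context.Exit()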