Add XLA context to the function cache key, so nested tf.function can be supported in TPUStrategy.

PiperOrigin-RevId: 279997866
Change-Id: Ic24ef515fbac27352c92c9b2ca025efbdc2323e0
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 6c807e6..03a26ec 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -54,6 +54,7 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import functional_ops
@@ -125,8 +126,12 @@
 
 
 CacheKey = collections.namedtuple("CacheKey", [
-    "input_signature", "parent_graph", "device_functions", "colocation_stack",
-    "in_cross_replica_context"
+    "input_signature",
+    "parent_graph",
+    "device_functions",
+    "colocation_stack",
+    "in_cross_replica_context",
+    "xla_context_id",
 ])
 
 
@@ -356,6 +361,23 @@
   return "__inference_%s_%s" % (n, ops.uid())
 
 
+def _enclosing_xla_context():
+  """Returns the enclosing XLAControlFlowContext (from tpu.rewrite()), or None."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
+
+
 class _EagerDefinedFunctionDeleter(object):
   """Unregister function from eager context."""
 
@@ -2511,6 +2533,10 @@
         device_functions = (pydev.merge_device(ctx.device_name),)
       else:
         device_functions = ()
+
+      # Eager execution never runs inside an XLA context, so
+      # `xla_context_id` is always 0 here.
+      xla_context_id = 0
     else:
       colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
       if (uses_distribution_strategy
@@ -2522,6 +2548,14 @@
       else:
         device_functions = ()
 
+      # We want to force function retracing for each different
+      # XLAControlFlowContext, so add `xla_context_id` to the cache key.
+      tpu_context = _enclosing_xla_context()
+      if tpu_context is not None:
+        xla_context_id = id(tpu_context)
+      else:
+        xla_context_id = 0
+
     in_cross_replica_context = False
     try:
       in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
@@ -2529,11 +2563,9 @@
       pass
 
     return CacheKey(
-        _make_input_signature_hashable(input_signature),
-        parent_graph,
-        device_functions,
-        colocation_stack,
-        in_cross_replica_context)
+        _make_input_signature_hashable(input_signature), parent_graph,
+        device_functions, colocation_stack, in_cross_replica_context,
+        xla_context_id)
 
   def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
     """Create a `ConcreteFunction` from `args` and `kwargs`."""