Add the XLA context to the function cache key, so that nested tf.function calls can be supported in TPUStrategy.
PiperOrigin-RevId: 279997866
Change-Id: Ic24ef515fbac27352c92c9b2ca025efbdc2323e0
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 6c807e6..03a26ec 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -54,6 +54,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
@@ -125,8 +126,12 @@
CacheKey = collections.namedtuple("CacheKey", [
- "input_signature", "parent_graph", "device_functions", "colocation_stack",
- "in_cross_replica_context"
+ "input_signature",
+ "parent_graph",
+ "device_functions",
+ "colocation_stack",
+ "in_cross_replica_context",
+ "xla_context_id",
])
@@ -356,6 +361,23 @@
return "__inference_%s_%s" % (n, ops.uid())
+def _enclosing_xla_context():
+ """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
+ graph = ops.get_default_graph()
+ while graph is not None:
+ # pylint: disable=protected-access
+ context_ = graph._get_control_flow_context()
+ # pylint: enable=protected-access
+ while context_ is not None:
+ if isinstance(context_, control_flow_ops.XLAControlFlowContext):
+ return context_
+ context_ = context_.outer_context
+ # This may be a FuncGraph due to defuns or v2 control flow. We need to
+ # find the original graph with the XLAControlFlowContext.
+ graph = getattr(graph, "outer_graph", None)
+ return None
+
+
class _EagerDefinedFunctionDeleter(object):
"""Unregister function from eager context."""
@@ -2511,6 +2533,10 @@
device_functions = (pydev.merge_device(ctx.device_name),)
else:
device_functions = ()
+
+ # We should not be in XLA context in eager mode. So always set
+ # `xla_context_id` to 0.
+ xla_context_id = 0
else:
colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
if (uses_distribution_strategy
@@ -2522,6 +2548,14 @@
else:
device_functions = ()
+ # We want to force function retracing for each different
+ # XLAControlFlowContext, so add `xla_context_id` to the cache key.
+ tpu_context = _enclosing_xla_context()
+ if tpu_context is not None:
+ xla_context_id = id(tpu_context)
+ else:
+ xla_context_id = 0
+
in_cross_replica_context = False
try:
in_cross_replica_context = (strategy_stack[-1].replica_context is None) # pylint: disable=protected-access
@@ -2529,11 +2563,9 @@
pass
return CacheKey(
- _make_input_signature_hashable(input_signature),
- parent_graph,
- device_functions,
- colocation_stack,
- in_cross_replica_context)
+ _make_input_signature_hashable(input_signature), parent_graph,
+ device_functions, colocation_stack, in_cross_replica_context,
+ xla_context_id)
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
"""Create a `ConcreteFunction` from `args` and `kwargs`."""