Add a unit test for single core TPU jit compile with outside compilation.
This test currently fails but should work in the future once the TF2XLA rewrite passes are enabled for this flow.
PiperOrigin-RevId: 432481349
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index 5642040..25cf440 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -45,9 +45,11 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import flags
@@ -113,6 +115,40 @@
result = foo(a)
self.assertAllEqual(6, result)
+ # In this case, the entire computation in foo is compiled using JIT
+ # compilation and contains unsupported ops that should be outside compiled.
+ def test_single_tpu_jit_compile_with_outside_compilation(self):
+ if FLAGS.tpu_use_tfrt:
+ self.skipTest(
+ "This test triggers _XlaCompile and XlaLaunch which are not "
+ "supported in tfrt yet. We should avoid using these kernels on TPU. "
+ "However, it is a workaround to support b/129842431. We need more "
+ "discussion about how to support it in the long term.")
+ config.set_soft_device_placement(True)
+ with ops.device("/device:TPU:0"):
+ a = variables.Variable(1)
+
+ def get_a_plus_one():
+ return a + 1
+
+ @def_function.function(
+ input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
+ def foo(x):
+ b = x + get_a_plus_one()
+ my_str = string_ops.as_string(b)
+ new_str = my_str + "0"
+ c = string_ops.string_to_number(new_str, out_type=dtypes.int32)
+ logging_ops.print_v2(c)
+ b = c + get_a_plus_one()
+ return b + 1
+
+ # TODO(b/222338429): Replace this assert once outside compilation is
+ # supported with jit_compile.
+ with self.assertRaises(errors.InvalidArgumentError):
+ with ops.device("/device:TPU:0"):
+ foo(a)
+ # self.assertAllEqual(6, result)
+
# In this case, each of the ops in the TPU device scope are compiled and run
# individually.
def test_single_tpu_on_demand(self):