Fix input_lib_type_spec_test when TF2_BEHAVIOR=1.

TPUStrategy does not support device prefetch for distributed datasets that
contain sparse or ragged tensors, so the test now passes
InputOptions(experimental_prefetch_to_device=False) to
experimental_distribute_dataset when the distribution is a TPUStrategy or
TPUStrategyV2.

PiperOrigin-RevId: 322502827
Change-Id: I75d4d7086cbbb33fc12cb04d25e990f30b77897c
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 63c4d27..11cb725e 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -984,6 +984,7 @@
         ":multi_worker_test_base",
         ":reduce_util",
         ":strategy_combinations",
+        ":tpu_strategy",
         ":values",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:errors",
diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py
index 7f5b0e0..691b292 100644
--- a/tensorflow/python/distribute/input_lib_type_spec_test.py
+++ b/tensorflow/python/distribute/input_lib_type_spec_test.py
@@ -27,6 +27,7 @@
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
@@ -340,7 +341,17 @@
     distribution.extended.experimental_enable_get_next_as_optional = (
         enable_get_next_as_optional)
 
-    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    if isinstance(distribution,
+                  (tpu_strategy.TPUStrategyV2, tpu_strategy.TPUStrategy)):
+      # TPUStrategy does not support distributed datasets with device prefetch
+      # when using sparse or ragged tensors.
+      options = distribute_lib.InputOptions(
+          experimental_prefetch_to_device=False)
+    else:
+      options = None
+
+    dist_dataset = distribution.experimental_distribute_dataset(
+        dataset, options)
     with distribution.scope():
       iterator = iter(dist_dataset)
       _check_type_spec_structure(iterator)
@@ -395,7 +406,17 @@
     distribution.extended.experimental_enable_get_next_as_optional = (
         enable_get_next_as_optional)
 
-    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    if isinstance(distribution,
+                  (tpu_strategy.TPUStrategyV2, tpu_strategy.TPUStrategy)):
+      # TPUStrategy does not support distributed datasets with device prefetch
+      # when using sparse or ragged tensors.
+      options = distribute_lib.InputOptions(
+          experimental_prefetch_to_device=False)
+    else:
+      options = None
+
+    dist_dataset = distribution.experimental_distribute_dataset(
+        dataset, options)
     with distribution.scope():
       for _ in range(3):
         iterator = iter(dist_dataset)
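
For reference, a minimal sketch of the pattern the test now follows, assuming
the public TF 2.3-era tf.distribute API (tf.distribute.TPUStrategy,
tf.distribute.experimental.TPUStrategy, and tf.distribute.InputOptions mirror
the internal symbols used in the test; the helper name and the caller-supplied
dataset are hypothetical):

import tensorflow as tf

def distribute_with_optional_prefetch(strategy, dataset):
  """Distributes `dataset`, disabling device prefetch under TPU strategies.

  TPU strategies cannot prefetch sparse or ragged elements to the device,
  so in that case the distributed dataset is kept on the host.
  """
  if isinstance(strategy, (tf.distribute.TPUStrategy,
                           tf.distribute.experimental.TPUStrategy)):
    options = tf.distribute.InputOptions(experimental_prefetch_to_device=False)
  else:
    options = None
  return strategy.experimental_distribute_dataset(dataset, options)

Non-TPU strategies keep the default behavior (options=None), so the change is
a no-op for them.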