Fix input_lib_type_spec_test when TF2_BEHAVIOR=1.
PiperOrigin-RevId: 322502827
Change-Id: I75d4d7086cbbb33fc12cb04d25e990f30b77897c
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 63c4d27..11cb725e 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -984,6 +984,7 @@
":multi_worker_test_base",
":reduce_util",
":strategy_combinations",
+ ":tpu_strategy",
":values",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:errors",
diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py
index 7f5b0e0..691b292 100644
--- a/tensorflow/python/distribute/input_lib_type_spec_test.py
+++ b/tensorflow/python/distribute/input_lib_type_spec_test.py
@@ -27,6 +27,7 @@
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
@@ -340,7 +341,17 @@
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
- dist_dataset = distribution.experimental_distribute_dataset(dataset)
+ if isinstance(distribution,
+ (tpu_strategy.TPUStrategyV2, tpu_strategy.TPUStrategy)):
+ # TPUStrategy does not support distributed datasets with device prefetch
+ # when using sparse or ragged tensors.
+ options = distribute_lib.InputOptions(
+ experimental_prefetch_to_device=False)
+ else:
+ options = None
+
+ dist_dataset = distribution.experimental_distribute_dataset(
+ dataset, options)
with distribution.scope():
iterator = iter(dist_dataset)
_check_type_spec_structure(iterator)
@@ -395,7 +406,17 @@
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
- dist_dataset = distribution.experimental_distribute_dataset(dataset)
+ if isinstance(distribution,
+ (tpu_strategy.TPUStrategyV2, tpu_strategy.TPUStrategy)):
+ # TPUStrategy does not support distributed datasets with device prefetch
+ # when using sparse or ragged tensors.
+ options = distribute_lib.InputOptions(
+ experimental_prefetch_to_device=False)
+ else:
+ options = None
+
+ dist_dataset = distribution.experimental_distribute_dataset(
+ dataset, options)
with distribution.scope():
for _ in range(3):
iterator = iter(dist_dataset)