Add ragged tensor support to tf.nn.experimental.stateless_dropout .
PiperOrigin-RevId: 461087308
diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
index bcba7e8..b12d038 100644
--- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py
+++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
@@ -153,6 +153,11 @@
'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
'rate': 0.5,
'seed': 1},
+ {'op': nn_ops.stateless_dropout,
+ 'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
+ 'rate': 0.5,
+ 'seed': [1, 0],
+ 'rng_alg': 'auto_select'},
{'op': math_ops.nextafter,
'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
'x2': 0},
diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py
index 51d5723..2374bc9 100644
--- a/tensorflow/python/ops/ragged/ragged_math_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_math_ops.py
@@ -1119,6 +1119,23 @@
nn_ops.dropout_v2(x.flat_values, rate=rate, seed=seed))
+@dispatch.dispatch_for_api(nn_ops.stateless_dropout)
+def stateless_dropout(x: ragged_tensor.Ragged,
+ rate,
+ seed,
+ rng_alg=None,
+ noise_shape=None,
+ name=None):
+ """Ragged dispatch target for tf.nn.experimental.stateless_dropout."""
+ if noise_shape is not None:
+ raise ValueError('noise_shape is not supported yet for RaggedTensor x')
+ with ops.name_scope(name, 'RaggedNNStatelessDropout', [x, rate]):
+ x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
+ return x.with_flat_values(
+ nn_ops.stateless_dropout(
+ x.flat_values, rate=rate, seed=seed, rng_alg=rng_alg))
+
+
#===============================================================================
# Ragged version of Tensor.__eq__ and Tensor.__ne__
#===============================================================================