Disallow Keras LSTMCell to be used with DropoutWrapper.
See https://github.com/tensorflow/tensorflow/issues/33690 for details.
PiperOrigin-RevId: 277171614
Change-Id: I9925801a2b0c055c595ff20ecf063d6dd876e18e
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
index b9284d4..8e3948c 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
@@ -26,6 +26,7 @@
from tensorflow.python.keras.layers import AbstractRNNCell
+from tensorflow.python.keras.layers import LSTMCell
from tensorflow.python.ops import rnn_cell_wrapper_impl
from tensorflow.python.util.tf_export import tf_export
@@ -96,6 +97,10 @@
def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
super(DropoutWrapper, self).__init__(*args, **kwargs)
+ if isinstance(self.cell, LSTMCell):
+ raise ValueError("keras LSTM cell does not work with DropoutWrapper. "
+ "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
+ "instead.")
__init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
index 9ab88b9..15cbf68 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
@@ -225,7 +225,7 @@
def testDropoutWrapperSerialization(self):
wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
- cell = layers.LSTMCell(10)
+ cell = layers.GRUCell(10)
wrapper = wrapper_cls(cell)
config = wrapper.get_config()
@@ -249,6 +249,17 @@
reconstructed_wrapper = wrapper_cls.from_config(config)
self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
+ def testDroputWrapperWithKerasLSTMCell(self):
+ wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
+ cell = layers.LSTMCell(10)
+
+ with self.assertRaisesRegexp(ValueError, "does not work with "):
+ wrapper_cls(cell)
+
+ cell = layers.LSTMCell_v2(10)
+ with self.assertRaisesRegexp(ValueError, "does not work with "):
+ wrapper_cls(cell)
+
if __name__ == "__main__":
test.main()