Disallow Keras LSTMCell to be used with DropoutWrapper.

See https://github.com/tensorflow/tensorflow/issues/33690 for details.

PiperOrigin-RevId: 277171614
Change-Id: I9925801a2b0c055c595ff20ecf063d6dd876e18e
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
index b9284d4..8e3948c 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
@@ -26,6 +26,7 @@
 
 
 from tensorflow.python.keras.layers import AbstractRNNCell
+from tensorflow.python.keras.layers import LSTMCell
 from tensorflow.python.ops import rnn_cell_wrapper_impl
 from tensorflow.python.util.tf_export import tf_export
 
@@ -96,6 +97,10 @@
 
   def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
     super(DropoutWrapper, self).__init__(*args, **kwargs)
+    if isinstance(self.cell, LSTMCell):
+      raise ValueError("keras LSTM cell does not work with DropoutWrapper. "
+                       "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
+                       "instead.")
 
   __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
 
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
index 9ab88b9..15cbf68 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
@@ -225,7 +225,7 @@
 
   def testDropoutWrapperSerialization(self):
     wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
-    cell = layers.LSTMCell(10)
+    cell = layers.GRUCell(10)
     wrapper = wrapper_cls(cell)
     config = wrapper.get_config()
 
@@ -249,6 +249,17 @@
     reconstructed_wrapper = wrapper_cls.from_config(config)
     self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
 
+  def testDroputWrapperWithKerasLSTMCell(self):
+    wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
+    cell = layers.LSTMCell(10)
+
+    with self.assertRaisesRegexp(ValueError, "does not work with "):
+      wrapper_cls(cell)
+
+    cell = layers.LSTMCell_v2(10)
+    with self.assertRaisesRegexp(ValueError, "does not work with "):
+      wrapper_cls(cell)
+
 
 if __name__ == "__main__":
   test.main()