Update keras test to direct import legacy rnn cells.
PiperOrigin-RevId: 371962138
Change-Id: Iaf521eec2bac3510db232ae040e6e657f177c34d
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index 7bb29d4..acf4ed5 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -34,11 +34,11 @@
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.layers import recurrent as rnn_v1
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
+from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as variables_lib
@@ -1291,7 +1291,7 @@
recurrent_activation='sigmoid',
implementation=2)
tf_lstm_cell_output = _run_cell(
- rnn_cell.LSTMCell,
+ rnn_cell_impl.LSTMCell,
use_peepholes=True,
initializer=init_ops.ones_initializer)
self.assertNotAllClose(first_implementation_output, no_peephole_output)