Unify V1/V2 layer naming in internal imports.
PiperOrigin-RevId: 285302761
Change-Id: Ib704512d4076487ff39ededc867532917cbebc52
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index 87dfa34..07cb1bd 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -18,8 +18,11 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python import tf2
+
# Generic layers.
# pylint: disable=g-bad-import-order
+# pylint: disable=g-import-not-at-top
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.input_spec import InputSpec
@@ -27,10 +30,20 @@
from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer
# Preprocessing layers.
-from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
-from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
-from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
-from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
+if tf2.enabled():
+ from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
+ from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
+ NormalizationV2 = Normalization
+ from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
+ from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
+ TextVectorizationV2 = TextVectorization
+else:
+ from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization
+ from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2
+ NormalizationV1 = Normalization
+ from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
+ from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
+ TextVectorizationV1 = TextVectorization
# Advanced activations.
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
@@ -121,8 +134,14 @@
# Normalization layers.
from tensorflow.python.keras.layers.normalization import LayerNormalization
-from tensorflow.python.keras.layers.normalization import BatchNormalization
-from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
+if tf2.enabled():
+ from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
+ from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1
+ BatchNormalizationV2 = BatchNormalization
+else:
+ from tensorflow.python.keras.layers.normalization import BatchNormalization
+ from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
+ BatchNormalizationV1 = BatchNormalization
# Kernelized layers.
from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
@@ -163,14 +182,32 @@
from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
from tensorflow.python.keras.layers.recurrent import SimpleRNN
-from tensorflow.python.keras.layers.recurrent import GRU
-from tensorflow.python.keras.layers.recurrent import GRUCell
-from tensorflow.python.keras.layers.recurrent import LSTM
-from tensorflow.python.keras.layers.recurrent import LSTMCell
-from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRU_v2
-from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCell_v2
-from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTM_v2
-from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCell_v2
+if tf2.enabled():
+ from tensorflow.python.keras.layers.recurrent_v2 import GRU
+ from tensorflow.python.keras.layers.recurrent_v2 import GRUCell
+ from tensorflow.python.keras.layers.recurrent_v2 import LSTM
+ from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell
+ from tensorflow.python.keras.layers.recurrent import GRU as GRUV1
+ from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1
+ from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1
+ from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1
+ GRUV2 = GRU
+ GRUCellV2 = GRUCell
+ LSTMV2 = LSTM
+ LSTMCellV2 = LSTMCell
+else:
+ from tensorflow.python.keras.layers.recurrent import GRU
+ from tensorflow.python.keras.layers.recurrent import GRUCell
+ from tensorflow.python.keras.layers.recurrent import LSTM
+ from tensorflow.python.keras.layers.recurrent import LSTMCell
+ from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2
+ from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2
+ from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2
+ from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2
+ GRUV1 = GRU
+ GRUCellV1 = GRUCell
+ LSTMV1 = LSTM
+ LSTMCellV1 = LSTMCell
# Convolutional-recurrent layers.
from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index e3e193c..1c20918 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -460,7 +460,7 @@
input_shape = (3, 5)
def gru(cudnn=False, **kwargs):
- layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRU
+ layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRUV1
return layer_class(2, input_shape=input_shape, **kwargs)
def get_layer_weights(layer):
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
index 15cbf68..a01e56b 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
@@ -256,7 +256,7 @@
with self.assertRaisesRegexp(ValueError, "does not work with "):
wrapper_cls(cell)
- cell = layers.LSTMCell_v2(10)
+ cell = layers.LSTMCellV2(10)
with self.assertRaisesRegexp(ValueError, "does not work with "):
wrapper_cls(cell)
diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py
index dc97b27..28101cf 100644
--- a/tensorflow/python/keras/saving/hdf5_format_test.py
+++ b/tensorflow/python/keras/saving/hdf5_format_test.py
@@ -130,7 +130,7 @@
(None, input_dim, 4, 4, 4),
],
[
- (keras.layers.GRU(output_dim)),
+ (keras.layers.GRUV1(output_dim)),
[np.random.random((input_dim, output_dim)),
np.random.random((output_dim, output_dim)),
np.random.random((output_dim,)),
@@ -143,7 +143,7 @@
(None, 4, input_dim),
],
[
- (keras.layers.LSTM(output_dim)),
+ (keras.layers.LSTMV1(output_dim)),
[np.random.random((input_dim, output_dim)),
np.random.random((output_dim, output_dim)),
np.random.random((output_dim,)),
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 9bf7aef..98f0426 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -20,7 +20,7 @@
from __future__ import print_function
-from tensorflow.python.keras import layers as keras_layers
+from tensorflow.python.keras.layers import normalization as keras_normalization
from tensorflow.python.layers import base
from tensorflow.python.ops import init_ops
from tensorflow.python.util import deprecation
@@ -28,7 +28,7 @@
@tf_export(v1=['layers.BatchNormalization'])
-class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
+class BatchNormalization(keras_normalization.BatchNormalization, base.Layer):
"""Batch Normalization layer from (Ioffe et al., 2015).
Keras APIs handle BatchNormalization updates to the moving_mean and
@@ -175,7 +175,7 @@
@deprecation.deprecated(
date=None, instructions='Use keras.layers.BatchNormalization instead. In '
'particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not '
- 'be used (consult the `tf.keras.layers.batch_normalization` '
+ 'be used (consult the `tf.keras.layers.BatchNormalization` '
'documentation).')
@tf_export(v1=['layers.batch_normalization'])
def batch_normalization(inputs,