# blob: 962b557d7a93c4fbb9958518e28c57a5591376ed [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests deprecation warnings in a few special cases."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import module_wrapper
# Set the module wrapper's private per-module warning limit to 5 when this
# test module is imported.
# NOTE(review): presumably this caps how many deprecation warnings are logged
# per wrapped module -- confirm against module_wrapper.
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
class DeprecationTest(test.TestCase):
  """Checks which `tf.*` accesses log a warning, and how many times.

  Every test replaces `logging.warning` with an autospec'd mock and inspects
  the recorded calls instead of real log output. The common shape is: the
  non-deprecated spelling logs nothing, the deprecated spelling logs once, and
  repeating the deprecated spelling does not log again.
  """

  def _assertRenameWarning(self, mock_warning, deprecated_regex,
                           replacement_regex):
    """Asserts on the positional arguments of the latest warning call.

    Args:
      mock_warning: Mock standing in for `logging.warning`.
      deprecated_regex: Regex searched for in positional argument 2.
      replacement_regex: Regex searched for in positional argument 3.
    """
    warning_args = mock_warning.call_args[0]
    # NOTE(review): argument 1 is presumably the call-site location, argument 2
    # the deprecated name and argument 3 the suggested replacement -- confirm
    # against the code that formats the warning.
    self.assertRegex(warning_args[1], "deprecation_test.py:")
    self.assertRegex(warning_args[2], deprecated_regex)
    self.assertRegex(warning_args[3], replacement_regex)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testDeprecatedFunction(self, mock_warning):
    """`tf.tables_initializer` warns once; the compat.v1 spelling never does."""
    self.assertEqual(0, mock_warning.call_count)
    tf.compat.v1.initializers.tables_initializer()
    self.assertEqual(0, mock_warning.call_count)

    tf.tables_initializer()
    self.assertEqual(1, mock_warning.call_count)
    self._assertRenameWarning(
        mock_warning,
        r"tables_initializer",
        r"compat.v1.tables_initializer")

    # A second use of the same deprecated name must not log again.
    tf.tables_initializer()
    self.assertEqual(1, mock_warning.call_count)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testDeprecatedClass(self, mock_warning):
    """`tf.ragged.RaggedTensorValue` warns once; compat.v1 never does."""
    value = np.array([1, 2, 3])
    row_splits = np.array([1])

    self.assertEqual(0, mock_warning.call_count)
    tf.compat.v1.ragged.RaggedTensorValue(value, row_splits)
    self.assertEqual(0, mock_warning.call_count)

    tf.ragged.RaggedTensorValue(value, row_splits)
    self.assertEqual(1, mock_warning.call_count)
    self._assertRenameWarning(
        mock_warning,
        r"ragged.RaggedTensorValue",
        r"compat.v1.ragged.RaggedTensorValue")

    # A second use of the same deprecated name must not log again.
    tf.ragged.RaggedTensorValue(value, row_splits)
    self.assertEqual(1, mock_warning.call_count)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testDeprecatedFunctionEndpoint(self, mock_warning):
    """`tf.sparse_mask` warns once; `tf.sparse.mask` never does."""
    array = tf.IndexedSlices(
        tf.compat.v1.convert_to_tensor(np.array([1, 2])),
        tf.compat.v1.convert_to_tensor(np.array([0, 2])))
    mask_indices = tf.compat.v1.convert_to_tensor(np.array([2]))

    self.assertEqual(0, mock_warning.call_count)
    tf.sparse.mask(array, mask_indices)
    self.assertEqual(0, mock_warning.call_count)

    tf.sparse_mask(array, mask_indices)
    self.assertEqual(1, mock_warning.call_count)
    self._assertRenameWarning(mock_warning, r"sparse_mask", "sparse.mask")

    # A second use of the same deprecated name must not log again.
    tf.sparse_mask(array, mask_indices)
    self.assertEqual(1, mock_warning.call_count)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testDeprecatedClassEndpoint(self, mock_warning):
    """`tf.VarLenFeature` warns once; `tf.io.VarLenFeature` never does."""
    self.assertEqual(0, mock_warning.call_count)
    tf.io.VarLenFeature(tf.dtypes.int32)
    self.assertEqual(0, mock_warning.call_count)

    tf.VarLenFeature(tf.dtypes.int32)
    self.assertEqual(1, mock_warning.call_count)
    self._assertRenameWarning(
        mock_warning, r"VarLenFeature", r"io.VarLenFeature")

    # A second use of the same deprecated name must not log again.
    tf.VarLenFeature(tf.dtypes.int32)
    self.assertEqual(1, mock_warning.call_count)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testDeprecatedConstantEndpoint(self, mock_warning):
    """Reading the constant via `signature_constants` warns exactly once."""
    self.assertEqual(0, mock_warning.call_count)
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # pylint: disable=pointless-statement
    self.assertEqual(0, mock_warning.call_count)

    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # pylint: disable=pointless-statement
    self.assertEqual(1, mock_warning.call_count)
    self._assertRenameWarning(
        mock_warning,
        r"saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY",
        r"saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY")

    # A second read of the same deprecated name must not log again.
    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # pylint: disable=pointless-statement
    self.assertEqual(1, mock_warning.call_count)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testKerasDeprecationNoWarning(self, mock_warning):
    """Building a `GRUCell` logs at most one warning, about initializers."""
    self.assertEqual(0, mock_warning.call_count)
    tf.keras.layers.GRUCell(20)
    self.assertLessEqual(mock_warning.call_count, 1)
    if mock_warning.call_count == 1:
      # The only message printed should be due to referencing init op.
      self.assertRegex(
          mock_warning.call_args[0][-1],
          "Call initializer instance with the dtype argument instead of "
          "passing it to the constructor")

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testKerasDeprecation(self, mock_warning):
    """`tf.keras.backend.get_session` warns once, naming the compat.v1 path."""
    self.assertEqual(0, mock_warning.call_count)
    tf.keras.backend.get_session()
    self.assertEqual(1, mock_warning.call_count)
    self.assertRegex(
        mock_warning.call_args[0][-1],
        "tf.compat.v1.keras.backend.get_session")

    # A second call must not log again.
    tf.keras.backend.get_session()
    self.assertEqual(1, mock_warning.call_count)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testKerasEndpointDeprecation(self, mock_warning):
    """`tf.keras.metrics.cosine_proximity` warns once, naming the new path."""
    self.assertEqual(0, mock_warning.call_count)
    tf.keras.metrics.cosine_proximity([0.5], [0.5])
    self.assertEqual(1, mock_warning.call_count)
    self.assertRegex(
        mock_warning.call_args[0][-1],
        "tf.keras.losses.cosine_similarity")

    # A second call must not log again.
    tf.keras.metrics.cosine_proximity([0.5], [0.5])
    self.assertEqual(1, mock_warning.call_count)

  @test.mock.patch.object(logging, "warning", autospec=True)
  def testEstimatorDeprecation(self, mock_warning):
    """Constructing the experimental `KMeans` estimator logs two warnings.

    The body is a no-op when `KMeans` is not present in
    `tf.estimator.experimental.__dict__`.
    """
    if "KMeans" not in tf.estimator.experimental.__dict__:
      return

    self.assertEqual(0, mock_warning.call_count)
    tf.estimator.experimental.KMeans(2)
    self.assertEqual(2, mock_warning.call_count)

    recorded_calls = mock_warning.call_args_list
    # The second recorded call (index 1) is not a deprecation warning; its
    # first positional argument is checked.
    self.assertRegex(
        recorded_calls[1][0][0],
        "Using temporary folder as model directory:")
    # The first recorded call (index 0) is the deprecation warning; its last
    # positional argument is checked.
    self.assertRegex(
        recorded_calls[0][0][-1],
        "tf.compat.v1.estimator.experimental.KMeans")
# Run this module's tests through the TensorFlow test runner when the file is
# executed directly rather than imported.
if __name__ == "__main__":
  test.main()