# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""TensorFlow API compatibility tests.
This test ensures all changes to the public API of TensorFlow are intended.
If this test fails, it means a change has been made to the public API. Backwards
incompatible changes are not allowed. You can run the test with
"--update_goldens" flag set to "True" to update goldens when making changes to
the public TF python API.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import re
import sys
import six
from six.moves import range
import tensorflow as tf
from google.protobuf import message
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.tools.api.lib import api_objects_pb2
from tensorflow.tools.api.lib import python_object_to_proto_visitor
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
# pylint: disable=g-import-not-at-top,unused-import
_TENSORBOARD_AVAILABLE = True
try:
  import tensorboard as _tb
except ImportError:
  _TENSORBOARD_AVAILABLE = False
# pylint: enable=g-import-not-at-top,unused-import
# FLAGS defined at the bottom:
FLAGS = None
# DEFINE_boolean, update_goldens, default False:
_UPDATE_GOLDENS_HELP = """
Update stored golden files if API is updated. WARNING: All API changes
have to be authorized by TensorFlow leads.
"""
# DEFINE_boolean, only_test_core_api, default True:
_ONLY_TEST_CORE_API_HELP = """
Some TF APIs are being moved outside of the tensorflow/ directory. There is
no guarantee which versions of these APIs will be present when running this
test. Therefore, do not error out on API changes in non-core TF code
if this flag is set.
"""
# DEFINE_boolean, verbose_diffs, default True:
_VERBOSE_DIFFS_HELP = """
If set to true, print line by line diffs on all libraries. If set to
false, only print which libraries have differences.
"""
_API_GOLDEN_FOLDER_V1 = resource_loader.get_path_to_datafile('../golden/v1')
_API_GOLDEN_FOLDER_V2 = resource_loader.get_path_to_datafile('../golden/v2')
_TEST_README_FILE = resource_loader.get_path_to_datafile('README.txt')
_UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
    'API_UPDATE_WARNING.txt')
_NON_CORE_PACKAGES = ['estimator']
# TODO(annarev): remove this once we test with newer version of
# estimator that actually has compat v1 version.
if not hasattr(tf.compat.v1, 'estimator'):
  tf.compat.v1.estimator = tf.estimator
  tf.compat.v2.estimator = tf.estimator


def _KeyToFilePath(key, api_version):
  """From a given key, construct a filepath.

  Filepath will be inside golden folder for api_version.

  Args:
    key: a string used to determine the file path
    api_version: a number indicating the tensorflow API version, e.g. 1 or 2.

  Returns:
    A string of file path to the pbtxt file which describes the public API
  """

  def _ReplaceCapsWithDash(matchobj):
    match = matchobj.group(0)
    return '-%s' % (match.lower())

  case_insensitive_key = re.sub('([A-Z]{1})', _ReplaceCapsWithDash,
                                six.ensure_str(key))
  api_folder = (
      _API_GOLDEN_FOLDER_V2 if api_version == 2 else _API_GOLDEN_FOLDER_V1)
  return os.path.join(api_folder, '%s.pbtxt' % case_insensitive_key)
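# For illustration (not executed): _KeyToFilePath('tensorflow.TensorSpec', 2)
# maps each capital letter to '-<lowercase>' and yields a path like
# '<golden/v2>/tensorflow.-tensor-spec.pbtxt'.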


def _FileNameToKey(filename):
  """From a given filename, construct a key we use for api objects."""

  def _ReplaceDashWithCaps(matchobj):
    match = matchobj.group(0)
    return match[1].upper()

  base_filename = os.path.basename(filename)
  base_filename_without_ext = os.path.splitext(base_filename)[0]
  api_object_key = re.sub('((-[a-z]){1})', _ReplaceDashWithCaps,
                          six.ensure_str(base_filename_without_ext))
  return api_object_key
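# For illustration (not executed): this is the inverse of _KeyToFilePath, so a
# golden file named 'tensorflow.-tensor-spec.pbtxt' becomes the key
# 'tensorflow.TensorSpec'.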


def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
  """A Visitor that crashes on subclasses of generated proto classes."""
  # If the traversed object is a proto Message class
  if not (isinstance(parent, type) and issubclass(parent, message.Message)):
    return
  if parent is message.Message:
    return
  # Check that it is a direct subclass of Message.
  if message.Message not in parent.__bases__:
    raise NotImplementedError(
        'Object tf.%s is a subclass of a generated proto Message. '
        'They are not yet supported by the API tools.' % path)
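# For illustration (hypothetical): exposing `class MyOptions(config_pb2.ConfigProto)`
# under tf.* would raise here, while generated proto classes themselves (direct
# subclasses of message.Message) are allowed through.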


def _FilterNonCoreGoldenFiles(golden_file_list):
  """Filter out non-core API pbtxt files."""
  filtered_file_list = []
  filtered_package_prefixes = ['tensorflow.%s.' % p for p in _NON_CORE_PACKAGES]
  for f in golden_file_list:
    if any(
        six.ensure_str(f).rsplit('/')[-1].startswith(pre)
        for pre in filtered_package_prefixes):
      continue
    filtered_file_list.append(f)
  return filtered_file_list
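# For illustration (not executed): with _NON_CORE_PACKAGES = ['estimator'], any
# golden file whose basename starts with 'tensorflow.estimator.' (e.g.
# 'tensorflow.estimator.-estimator.pbtxt') is dropped from the comparison.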


def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map):
  """Filter out golden proto dict symbols that should be omitted."""
  if not omit_golden_symbols_map:
    return golden_proto_dict
  filtered_proto_dict = dict(golden_proto_dict)
  for key, symbol_list in six.iteritems(omit_golden_symbols_map):
    api_object = api_objects_pb2.TFAPIObject()
    api_object.CopyFrom(filtered_proto_dict[key])
    filtered_proto_dict[key] = api_object
    module_or_class = None
    if api_object.HasField('tf_module'):
      module_or_class = api_object.tf_module
    elif api_object.HasField('tf_class'):
      module_or_class = api_object.tf_class
    if module_or_class is not None:
      for members in (module_or_class.member, module_or_class.member_method):
        filtered_members = [m for m in members if m.name not in symbol_list]
        # Two steps because protobuf repeated fields disallow slice assignment.
        del members[:]
        members.extend(filtered_members)
  return filtered_proto_dict
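# For illustration (not executed): omit_golden_symbols_map maps an API object
# key to the member names to ignore, e.g.
#   {'tensorflow.summary': ['audio', 'histogram', 'image', 'scalar', 'text']}
# as used by the TF 2 compatibility tests below when TensorBoard is absent.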


class ApiCompatibilityTest(test.TestCase):

  def __init__(self, *args, **kwargs):
    super(ApiCompatibilityTest, self).__init__(*args, **kwargs)

    golden_update_warning_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _UPDATE_WARNING_FILE)
    self._update_golden_warning = file_io.read_file_to_string(
        golden_update_warning_filename)

    test_readme_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _TEST_README_FILE)
    self._test_readme_message = file_io.read_file_to_string(
        test_readme_filename)

  def _AssertProtoDictEquals(self,
                             expected_dict,
                             actual_dict,
                             verbose=False,
                             update_goldens=False,
                             additional_missing_object_message='',
                             api_version=2):
    """Diff given dicts of protobufs and report differences in a readable way.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys
    all_keys = expected_keys | actual_keys

    # This will be populated below.
    updated_keys = []

    for key in all_keys:
      diff_message = ''
      verbose_diff_message = ''
      # First check if the key is not found in one or the other.
      if key in only_in_expected:
        diff_message = 'Object %s expected but not found (removed). %s' % (
            key, additional_missing_object_message)
        verbose_diff_message = diff_message
      elif key in only_in_actual:
        diff_message = 'New object %s found (added).' % key
        verbose_diff_message = diff_message
      else:
        # Do not truncate diff
        self.maxDiff = None  # pylint: disable=invalid-name
        # Now we can run an actual proto diff.
        try:
          self.assertProtoEquals(expected_dict[key], actual_dict[key])
        except AssertionError as e:
          updated_keys.append(key)
          diff_message = 'Change detected in python object: %s.' % key
          verbose_diff_message = str(e)

      # All difference cases covered above. If any difference found, add to the
      # list.
      if diff_message:
        diffs.append(diff_message)
        verbose_diffs.append(verbose_diff_message)

    # If diffs are found, handle them based on flags.
    if diffs:
      diff_count = len(diffs)
      logging.error(self._test_readme_message)
      logging.error('%d differences found between API and golden.', diff_count)
      messages = verbose_diffs if verbose else diffs
      for i in range(diff_count):
        print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr)
      if update_goldens:
        # Write files if requested.
        logging.warning(self._update_golden_warning)

        # If the keys are only in expected, some objects are deleted.
        # Remove files.
        for key in only_in_expected:
          filepath = _KeyToFilePath(key, api_version)
          file_io.delete_file(filepath)

        # If the files are only in actual (current library), these are new
        # modules. Write them to files. Also record all updates in files.
        for key in only_in_actual | set(updated_keys):
          filepath = _KeyToFilePath(key, api_version)
          file_io.write_string_to_file(
              filepath, text_format.MessageToString(actual_dict[key]))
      else:
        # Fail if we cannot fix the test by updating goldens.
        self.fail('%d differences found between API and golden.' % diff_count)
    else:
      logging.info('No differences found between API and golden.')

  def testNoSubclassOfMessage(self):
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    # Skip compat.v1 and compat.v2 since they are validated in separate tests.
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

  def testNoSubclassOfMessageV1(self):
    if not hasattr(tf.compat, 'v1'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v1, visitor)

  def testNoSubclassOfMessageV2(self):
    if not hasattr(tf.compat, 'v2'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v2, visitor)

  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_pattern,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'].append('contrib')
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_pattern)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. If the test is run to update goldens, only report diffs but
    # do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)

  def testAPIBackwardsCompatibility(self):
    api_version = 1
    if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
      api_version = 2
    golden_file_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    omit_golden_symbols_map = {}
    if (api_version == 2 and FLAGS.only_test_core_api
        and not _TENSORBOARD_AVAILABLE):
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text']
    self._checkBackwardsCompatibility(
        tf,
        golden_file_pattern,
        api_version,
        # Skip compat.v1 and compat.v2 since they are validated
        # in separate tests.
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)
    # Check that V2 API does not have contrib
    self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))

  def testAPIBackwardsCompatibilityV1(self):
    api_version = 1
    golden_file_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    self._checkBackwardsCompatibility(
        tf.compat.v1, golden_file_pattern, api_version,
        additional_private_map={
            'tf': ['pywrap_tensorflow'],
            'tf.compat': ['v1', 'v2'],
        },
        omit_golden_symbols_map={'tensorflow': ['pywrap_tensorflow']})

  def testAPIBackwardsCompatibilityV2(self):
    api_version = 2
    golden_file_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    omit_golden_symbols_map = {}
    if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text']
    self._checkBackwardsCompatibility(
        tf.compat.v2,
        golden_file_pattern,
        api_version,
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP)
  # TODO(mikecase): Create Estimator's own API compatibility test or
  # a more general API compatibility test for use for TF components.
  parser.add_argument(
      '--only_test_core_api',
      type=bool,
      default=True,  # only_test_core_api default value
      help=_ONLY_TEST_CORE_API_HELP)
  parser.add_argument(
      '--verbose_diffs', type=bool, default=True, help=_VERBOSE_DIFFS_HELP)
  FLAGS, unparsed = parser.parse_known_args()
  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()