blob: 7b7c05b1eb27ce6416d3149d5cebc497f08f9974 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Owner(s): ["module: autograd"]
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
import pkgutil
import torch
import sys
from typing import Callable
import inspect
import json
import os
import unittest
class TestPublicBindings(TestCase):
def test_no_new_bindings(self):
"""
This test aims to stop the introduction of new JIT bindings into torch._C
whose names do not start with _. Such bindings are made available as
torch.XXX, which may not be desirable.
If your change causes this test to fail, add your new binding to a relevant
submodule of torch._C, such as torch._C._jit (or other relevant submodule of
torch._C). If your binding really needs to be available as torch.XXX, add it
to torch._C and add it to the allowlist below.
If you have removed a binding, remove it from the allowlist as well.
"""
# This allowlist contains every binding in torch._C that is copied into torch at
# the time of writing. It was generated with
#
# {elem for elem in dir(torch._C) if not elem.startswith("_")}
#
torch_C_allowlist_superset = {
"AggregationType",
"AliasDb",
"AnyType",
"Argument",
"ArgumentSpec",
"autocast_decrement_nesting",
"autocast_increment_nesting",
"AVG",
"BenchmarkConfig",
"BenchmarkExecutionStats",
"BFloat16StorageBase",
"Block",
"BoolStorageBase",
"BoolType",
"BufferDict",
"ByteStorageBase",
"CallStack",
"Capsule",
"CharStorageBase",
"ClassType",
"clear_autocast_cache",
"Code",
"CompilationUnit",
"CompleteArgumentSpec",
"ComplexDoubleStorageBase",
"ComplexFloatStorageBase",
"ComplexType",
"ConcreteModuleType",
"ConcreteModuleTypeBuilder",
"CONV_BN_FUSION",
"cpp",
"CudaBFloat16StorageBase",
"CudaBFloat16TensorBase",
"CudaBFloat16TensorBase",
"CudaBoolStorageBase",
"CudaBoolTensorBase",
"CudaBoolTensorBase",
"CudaByteStorageBase",
"CudaByteTensorBase",
"CudaByteTensorBase",
"CudaCharStorageBase",
"CudaCharTensorBase",
"CudaCharTensorBase",
"CudaComplexDoubleStorageBase",
"CudaComplexDoubleTensorBase",
"CudaComplexDoubleTensorBase",
"CudaComplexFloatStorageBase",
"CudaComplexFloatTensorBase",
"CudaComplexFloatTensorBase",
"CudaDoubleStorageBase",
"CudaDoubleTensorBase",
"CudaDoubleTensorBase",
"CudaFloatStorageBase",
"CudaFloatTensorBase",
"CudaHalfStorageBase",
"CudaHalfTensorBase",
"CudaIntStorageBase",
"CudaIntTensorBase",
"CudaIntTensorBase",
"CudaLongStorageBase",
"CudaLongTensorBase",
"CudaLongTensorBase",
"CudaShortStorageBase",
"CudaShortTensorBase",
"CudaShortTensorBase",
"DeepCopyMemoTable",
"default_generator",
"DeserializationStorageContext",
"device",
"DeviceObjType",
"DictType",
"DisableTorchFunction",
"DoubleStorageBase",
"dtype",
"EnumType",
"ErrorReport",
"ExecutionPlan",
"FatalError",
"FileCheck",
"finfo",
"FloatStorageBase",
"FloatType",
"fork",
"FunctionSchema",
"FUSE_ADD_RELU",
"Future",
"FutureType",
"Generator",
"get_autocast_cpu_dtype",
"get_default_dtype",
"get_num_interop_threads",
"get_num_threads",
"Gradient",
"Graph",
"GraphExecutorState",
"HalfStorageBase",
"has_cuda",
"has_cudnn",
"has_lapack",
"has_mkl",
"has_mkldnn",
"has_mlc",
"has_openmp",
"has_spectral",
"HOIST_CONV_PACKED_PARAMS",
"iinfo",
"import_ir_module_from_buffer",
"import_ir_module",
"InferredType",
"init_num_threads",
"INSERT_FOLD_PREPACK_OPS",
"InterfaceType",
"IntStorageBase",
"IntType",
"SymIntType",
"IODescriptor",
"is_anomaly_enabled",
"is_autocast_cache_enabled",
"is_autocast_cpu_enabled",
"is_autocast_enabled",
"is_grad_enabled",
"is_inference_mode_enabled",
"JITException",
"layout",
"ListType",
"LiteScriptModule",
"LockingLogger",
"LoggerBase",
"LongStorageBase",
"memory_format",
"merge_type_from_type_comment",
"MobileOptimizerType",
"ModuleDict",
"Node",
"NoneType",
"NoopLogger",
"NumberType",
"OperatorInfo",
"OptionalType",
"ParameterDict",
"parse_ir",
"parse_schema",
"parse_type_comment",
"PyObjectType",
"PyTorchFileReader",
"PyTorchFileWriter",
"QInt32StorageBase",
"QInt8StorageBase",
"qscheme",
"QUInt4x2StorageBase",
"QUInt2x4StorageBase",
"QUInt8StorageBase",
"read_vitals",
"REMOVE_DROPOUT",
"RRefType",
"ScriptClass",
"ScriptClassFunction",
"ScriptDict",
"ScriptDictIterator",
"ScriptDictKeyIterator",
"ScriptList",
"ScriptListIterator",
"ScriptFunction",
"ScriptMethod",
"ScriptModule",
"ScriptModuleSerializer",
"ScriptObject",
"ScriptObjectProperty",
"SerializationStorageContext",
"set_anomaly_enabled",
"set_autocast_cache_enabled",
"set_autocast_cpu_dtype",
"set_autocast_cpu_enabled",
"set_autocast_enabled",
"set_flush_denormal",
"set_num_interop_threads",
"set_num_threads",
"set_vital",
"ShortStorageBase",
"Size",
"StaticModule",
"Stream",
"StreamObjType",
"StringType",
"SUM",
"TensorType",
"ThroughputBenchmark",
"TracingState",
"TupleType",
"Type",
"unify_type_list",
"UnionType",
"Use",
"Value",
"autocast_decrement_nesting",
"autocast_increment_nesting",
"clear_autocast_cache",
"cpp",
"default_generator",
"device",
"dtype",
"finfo",
"fork",
"get_default_dtype",
"get_num_interop_threads",
"get_num_threads",
"has_cuda",
"has_cudnn",
"has_lapack",
"has_mkl",
"has_mkldnn",
"has_mlc",
"has_openmp",
"iinfo",
"import_ir_module",
"import_ir_module_from_buffer",
"init_num_threads",
"is_anomaly_enabled",
"is_autocast_enabled",
"is_grad_enabled",
"layout",
"memory_format",
"merge_type_from_type_comment",
"parse_ir",
"parse_schema",
"parse_type_comment",
"qscheme",
"set_anomaly_enabled",
"set_autocast_enabled",
'set_autocast_gpu_dtype',
'get_autocast_gpu_dtype',
"set_flush_denormal",
"set_num_interop_threads",
"set_num_threads",
"unify_type_list",
"vitals_enabled",
"wait",
}
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
# Check that the torch._C bindings are all in the allowlist. Since
# bindings can change based on how PyTorch was compiled (e.g. with/without
# CUDA), the two may not be an exact match but the bindings should be
# a subset of the allowlist.
difference = torch_C_bindings.difference(torch_C_allowlist_superset)
msg = f"torch._C had bindings that are not present in the allowlist:\n{difference}"
self.assertTrue(torch_C_bindings.issubset(torch_C_allowlist_superset), msg)
# AttributeError: module 'torch.distributed' has no attribute '_shard'
@unittest.skipIf(IS_WINDOWS, "Distributed Attribute Error")
def test_correct_module_names(self):
'''
An API is considered public, if its `__module__` starts with `torch.`
and there is no name in `__module__` or the object itself that starts with “_”.
Each public package should either:
- (preferred) Define `__all__` and all callables and classes in there must have their
`__module__` start with the current submodule's path. Things not in `__all__` should
NOT have their `__module__` start with the current submodule.
- (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their
`__module__` that start with the current submodule.
'''
failure_list = []
with open(os.path.join(os.path.dirname(__file__), 'allowlist_for_publicAPI.json')) as json_file:
# no new entries should be added to this allow_dict.
# New APIs must follow the public API guidelines.
allow_dict = json.load(json_file)
def test_module(modname):
split_strs = modname.split('.')
mod = sys.modules.get(modname)
for elem in split_strs:
if elem.startswith("_"):
return
def add_to_failure_list_if_not_in_allow_dict(modname, elem, elem_module):
if modname in allow_dict and elem in allow_dict[modname]:
return
failure_list.append((modname, elem, elem_module))
# verifies that each public API has the correct module name and naming semantics
def looks_public_or_not(elem, modname, mod, is_public=True):
obj = getattr(mod, elem)
if not (isinstance(obj, Callable) or inspect.isclass(obj)):
return
elem_module = getattr(obj, '__module__', None)
elem_modname_starts_with_mod = elem_module is not None and \
elem_module.startswith(modname) and '._' not in elem_module
# elem's name must NOT begin with an `_` and it's module name
# SHOULD start with it's current module since it's a public API
looks_public = not elem.startswith('_') and elem_modname_starts_with_mod
if is_public != looks_public:
add_to_failure_list_if_not_in_allow_dict(modname, elem, elem_module)
if hasattr(mod, '__all__'):
public_api = mod.__all__
all_api = dir(mod)
for elem in all_api:
looks_public_or_not(elem, modname, mod, is_public=elem in public_api)
else:
all_api = dir(mod)
for elem in all_api:
if not elem.startswith('_'):
looks_public_or_not(elem, modname, mod, is_public=True)
for _, modname, ispkg in pkgutil.walk_packages(path=torch.__path__, prefix=torch.__name__ + '.'):
test_module(modname)
test_module('torch')
msg = "Following new APIs ( displayed in the form (module, element, element module) )" \
" were added that do not meet our guidelines for public API" \
" Please review https://docs.google.com/document/d/10yx2-4gs0gTMOimVS403MnoAWkqitS8TUHX73PN8EjE/edit?pli=1#" \
" for more information:\n" + "\n".join(map(str, failure_list))
# empty lists are considered false in python
self.assertTrue(not failure_list, msg)
if __name__ == '__main__':
run_tests()