blob: f9961debf13c1ecacf2d0ff1ccd3716caf7dd7b1 [file] [log] [blame]
from torch.testing._internal.common_utils import TestCase
import importlib
from typing import List
class AOMigrationTestCase(TestCase):
    """Base class for tests checking that a legacy ``torch.quantization``
    location still exposes what now lives under ``torch.ao.quantization``.

    Failures are reported through ``TestCase.fail`` rather than bare
    ``assert`` statements, so the checks are not stripped under ``python -O``.
    """

    def _test_package_import(self, package_name: str, *,
                             old_base: str = 'torch.quantization',
                             new_base: str = 'torch.ao.quantization'):
        r"""Tests the module import by making sure that all the internals match
        (except the dunder methods).

        Imports ``{old_base}.{package_name}`` and ``{new_base}.{package_name}``
        and fails unless every non-dunder name in ``dir()`` of the old module
        is also present in ``dir()`` of the new module.

        Args:
            package_name: submodule name appended to both bases.
            old_base: dotted prefix of the legacy location.
            new_base: dotted prefix of the new location.
        """
        old_module = importlib.import_module(f'{old_base}.{package_name}')
        new_module = importlib.import_module(f'{new_base}.{package_name}')
        # Dunder names (e.g. __name__, __file__) are per-module and expected
        # to differ, so they are excluded before the subset check.
        old_module_dir = {
            el for el in dir(old_module)
            if not (el.startswith('__') and el.endswith('__'))
        }
        new_module_dir = set(dir(new_module))
        missing = old_module_dir - new_module_dir
        if missing:
            self.fail(
                f"Importing {old_module} vs. {new_module} does not match: "
                f"{missing}"
            )

    def _test_function_import(self, package_name: str, function_list: List[str], *,
                              old_base: str = 'torch.quantization',
                              new_base: str = 'torch.ao.quantization'):
        r"""Tests individual function list import by comparing the functions
        and their hashes.

        For each name in ``function_list``, fails if the name is missing from
        either ``{old_base}.{package_name}`` or ``{new_base}.{package_name}``,
        if the two objects do not compare equal, or if their hashes differ.

        Args:
            package_name: submodule name appended to both bases.
            function_list: attribute names to compare across the two modules.
            old_base: dotted prefix of the legacy location.
            new_base: dotted prefix of the new location.
        """
        old_location = importlib.import_module(f'{old_base}.{package_name}')
        new_location = importlib.import_module(f'{new_base}.{package_name}')
        for fn_name in function_list:
            # Report a missing name as a test failure that says where it is
            # missing, instead of letting getattr raise a bare AttributeError.
            for location in (old_location, new_location):
                if not hasattr(location, fn_name):
                    self.fail(f"{location.__name__} has no attribute {fn_name!r}")
            old_function = getattr(old_location, fn_name)
            new_function = getattr(new_location, fn_name)
            if not old_function == new_function:
                self.fail(f"Functions don't match: {fn_name}")
            old_hash = hash(old_function)
            new_hash = hash(new_function)
            if old_hash != new_hash:
                self.fail(
                    f"Hashes don't match: {old_function}({old_hash}) vs. "
                    f"{new_function}({new_hash})"
                )
class TestAOMigrationQuantizePy(AOMigrationTestCase):
    """Migration checks for the ``quantize`` and ``stubs`` submodules."""

    def test_package_import(self):
        """Old ``quantize`` module's non-dunder names all exist in the new one."""
        self._test_package_import('quantize')

    def test_function_import(self):
        """Listed ``quantize`` entries compare (and hash) equal at both paths."""
        expected_names = [
            '_convert',
            '_observer_forward_hook',
            '_propagate_qconfig_helper',
            '_remove_activation_post_process',
            '_remove_qconfig',
            'add_observer_',
            'add_quant_dequant',
            'convert',
            'get_observer_dict',
            'get_unique_devices_',
            'is_activation_post_process',
            'prepare',
            'prepare_qat',
            'propagate_qconfig_',
            'quantize',
            'quantize_dynamic',
            'quantize_qat',
            'register_activation_post_process_hook',
            'swap_module',
        ]
        self._test_function_import('quantize', function_list=expected_names)

    def test_package_import_stubs(self):
        """Old ``stubs`` module's non-dunder names all exist in the new one."""
        self._test_package_import('stubs')

    def test_function_import_stubs(self):
        """The stub wrapper classes compare (and hash) equal at both paths."""
        stub_names = ['QuantStub', 'DeQuantStub', 'QuantWrapper']
        self._test_function_import('stubs', function_list=stub_names)