| from .common import AOMigrationTestCase |
| |
| |
class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Import-compatibility checks for the move of `torch/quantization`
    into `torch/ao/quantization`.

    Every ``test_package_import_*`` method names one submodule and hands it
    to ``_test_package_import``. Every ``test_function_import_*`` method
    hands that submodule plus the list of names expected in it to
    ``_test_function_import``. Submodules that export mapping tables also
    pass those table names to ``_test_dict_import``.

    All three helpers are inherited from ``AOMigrationTestCase`` (imported
    from ``.common``) and are not defined in this class.
    """

    # -- quantize -------------------------------------------------------------

    def test_package_import_quantize(self):
        self._test_package_import("quantize")

    def test_function_import_quantize(self):
        self._test_function_import("quantize", [
            "_convert",
            "_observer_forward_hook",
            "_propagate_qconfig_helper",
            "_remove_activation_post_process",
            "_remove_qconfig",
            "add_observer_",
            "add_quant_dequant",
            "convert",
            "get_observer_dict",
            "get_unique_devices_",
            "is_activation_post_process",
            "prepare",
            "prepare_qat",
            "propagate_qconfig_",
            "quantize",
            "quantize_dynamic",
            "quantize_qat",
            "register_activation_post_process_hook",
            "swap_module",
        ])

    # -- stubs ----------------------------------------------------------------

    def test_package_import_stubs(self):
        self._test_package_import("stubs")

    def test_function_import_stubs(self):
        self._test_function_import("stubs", [
            "QuantStub",
            "DeQuantStub",
            "QuantWrapper",
        ])

    # -- quantize_jit ---------------------------------------------------------

    def test_package_import_quantize_jit(self):
        self._test_package_import("quantize_jit")

    def test_function_import_quantize_jit(self):
        self._test_function_import("quantize_jit", [
            "_check_is_script_module",
            "_check_forward_method",
            "script_qconfig",
            "script_qconfig_dict",
            "fuse_conv_bn_jit",
            "_prepare_jit",
            "prepare_jit",
            "prepare_dynamic_jit",
            "_convert_jit",
            "convert_jit",
            "convert_dynamic_jit",
            "_quantize_jit",
            "quantize_jit",
            "quantize_dynamic_jit",
        ])

    # -- fake_quantize --------------------------------------------------------

    def test_package_import_fake_quantize(self):
        self._test_package_import("fake_quantize")

    def test_function_import_fake_quantize(self):
        self._test_function_import("fake_quantize", [
            "_is_per_channel",
            "_is_per_tensor",
            "_is_symmetric_quant",
            "FakeQuantizeBase",
            "FakeQuantize",
            "FixedQParamsFakeQuantize",
            "FusedMovingAvgObsFakeQuantize",
            "default_fake_quant",
            "default_weight_fake_quant",
            "default_symmetric_fixed_qparams_fake_quant",
            "default_affine_fixed_qparams_fake_quant",
            "default_per_channel_weight_fake_quant",
            "default_histogram_fake_quant",
            "default_fused_act_fake_quant",
            "default_fused_wt_fake_quant",
            "default_fused_per_channel_wt_fake_quant",
            "_is_fake_quant_script_module",
            "disable_fake_quant",
            "enable_fake_quant",
            "disable_observer",
            "enable_observer",
        ])

    # -- fuse_modules ---------------------------------------------------------

    def test_package_import_fuse_modules(self):
        self._test_package_import("fuse_modules")

    def test_function_import_fuse_modules(self):
        self._test_function_import("fuse_modules", [
            "_fuse_modules",
            "_get_module",
            "_set_module",
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_known_modules",
            "fuse_modules",
            "get_fuser_method",
        ])

    # -- quant_type -----------------------------------------------------------

    def test_package_import_quant_type(self):
        self._test_package_import("quant_type")

    def test_function_import_quant_type(self):
        self._test_function_import("quant_type", [
            "QuantType",
            "quant_type_to_str",
        ])

    # -- observer -------------------------------------------------------------

    def test_package_import_observer(self):
        self._test_package_import("observer")

    def test_function_import_observer(self):
        self._test_function_import("observer", [
            "_PartialWrapper",
            "_with_args",
            "_with_callable_args",
            "ABC",
            "ObserverBase",
            "_ObserverBase",
            "MinMaxObserver",
            "MovingAverageMinMaxObserver",
            "PerChannelMinMaxObserver",
            "MovingAveragePerChannelMinMaxObserver",
            "HistogramObserver",
            "PlaceholderObserver",
            "RecordingObserver",
            "NoopObserver",
            "_is_activation_post_process",
            "_is_per_channel_script_obs_instance",
            "get_observer_state_dict",
            "load_observer_state_dict",
            "default_observer",
            "default_placeholder_observer",
            "default_debug_observer",
            "default_weight_observer",
            "default_histogram_observer",
            "default_per_channel_weight_observer",
            "default_dynamic_quant_observer",
            "default_float_qparams_observer",
        ])

    # -- qconfig --------------------------------------------------------------

    def test_package_import_qconfig(self):
        self._test_package_import("qconfig")

    def test_function_import_qconfig(self):
        self._test_function_import("qconfig", [
            "QConfig",
            "default_qconfig",
            "default_debug_qconfig",
            "default_per_channel_qconfig",
            "QConfigDynamic",
            "default_dynamic_qconfig",
            "float16_dynamic_qconfig",
            "float16_static_qconfig",
            "per_channel_dynamic_qconfig",
            "float_qparams_weight_only_qconfig",
            "default_qat_qconfig",
            "default_weight_only_qconfig",
            "default_activation_only_qconfig",
            "default_qat_qconfig_v2",
            "get_default_qconfig",
            "get_default_qat_qconfig",
            "assert_valid_qconfig",
            "QConfigAny",
            "add_module_to_qconfig_obs_ctr",
            "qconfig_equals",
        ])

    # -- quantization_mappings ------------------------------------------------

    def test_package_import_quantization_mappings(self):
        self._test_package_import("quantization_mappings")

    def test_function_import_quantization_mappings(self):
        callables = [
            "no_observer_set",
            "get_default_static_quant_module_mappings",
            "get_static_quant_module_class",
            "get_dynamic_quant_module_class",
            "get_default_qat_module_mappings",
            "get_default_dynamic_quant_module_mappings",
            "get_default_qconfig_propagation_list",
            "get_default_compare_output_module_list",
            "get_default_float_to_quantized_operator_mappings",
            "get_quantized_operator",
            "_get_special_act_post_process",
            "_has_special_act_post_process",
        ]
        mapping_tables = [
            "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_QAT_MODULE_MAPPINGS",
            "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
            # NOTE(review): "_INCLUDE_QCONFIG_PROPAGATE_LIST" is excluded from
            # this check; the reason is not recorded here -- confirm whether
            # it should be added back.
            "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
            "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
        ]
        self._test_function_import("quantization_mappings", callables)
        self._test_dict_import("quantization_mappings", mapping_tables)

    # -- fuser_method_mappings ------------------------------------------------

    def test_package_import_fuser_method_mappings(self):
        self._test_package_import("fuser_method_mappings")

    def test_function_import_fuser_method_mappings(self):
        callables = [
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_linear_bn",
            "get_fuser_method",
        ]
        mapping_tables = [
            "DEFAULT_OP_LIST_TO_FUSER_METHOD",
        ]
        self._test_function_import("fuser_method_mappings", callables)
        self._test_dict_import("fuser_method_mappings", mapping_tables)

    # -- utils ----------------------------------------------------------------

    def test_package_import_utils(self):
        self._test_package_import("utils")

    def test_function_import_utils(self):
        self._test_function_import("utils", [
            "activation_dtype",
            "activation_is_int8_quantized",
            "activation_is_statically_quantized",
            "calculate_qmin_qmax",
            "check_min_max_valid",
            "get_combined_dict",
            "get_qconfig_dtypes",
            "get_qparam_dict",
            "get_quant_type",
            "get_swapped_custom_module_class",
            "getattr_from_fqn",
            "is_per_channel",
            "is_per_tensor",
            "weight_dtype",
            "weight_is_quantized",
            "weight_is_statically_quantized",
        ])