| # Owner(s): ["oncall: quantization"] |
| |
| from .common import AOMigrationTestCase |
| |
| |
class TestAOMigrationQuantization(AOMigrationTestCase):
    """Import checks for the `torch/quantization` -> `torch/ao/quantization`
    migration.

    Each test hands a module name plus the list of names expected in that
    module to a checker method reached through ``self``
    (``_test_function_import`` for functions and classes,
    ``_test_dict_import`` for dictionaries). The checkers are not defined in
    this class; presumably they come from ``AOMigrationTestCase``.
    """

    # NOTE: the test methods deliberately use `#` comments rather than
    # docstrings, because unittest reports a test method's docstring as its
    # short description.

    def test_function_import_quantize(self):
        # Names checked for the "quantize" module.
        self._test_function_import(
            "quantize",
            [
                "_convert",
                "_observer_forward_hook",
                "_propagate_qconfig_helper",
                "_remove_activation_post_process",
                "_remove_qconfig",
                "_add_observer_",
                "add_quant_dequant",
                "convert",
                "_get_observer_dict",
                "_get_unique_devices_",
                "_is_activation_post_process",
                "prepare",
                "prepare_qat",
                "propagate_qconfig_",
                "quantize",
                "quantize_dynamic",
                "quantize_qat",
                "_register_activation_post_process_hook",
                "swap_module",
            ],
        )

    def test_function_import_stubs(self):
        # Names checked for the "stubs" module.
        self._test_function_import(
            "stubs",
            [
                "QuantStub",
                "DeQuantStub",
                "QuantWrapper",
            ],
        )

    def test_function_import_quantize_jit(self):
        # Names checked for the "quantize_jit" module.
        self._test_function_import(
            "quantize_jit",
            [
                "_check_is_script_module",
                "_check_forward_method",
                "script_qconfig",
                "script_qconfig_dict",
                "fuse_conv_bn_jit",
                "_prepare_jit",
                "prepare_jit",
                "prepare_dynamic_jit",
                "_convert_jit",
                "convert_jit",
                "convert_dynamic_jit",
                "_quantize_jit",
                "quantize_jit",
                "quantize_dynamic_jit",
            ],
        )

    def test_function_import_fake_quantize(self):
        # Names checked for the "fake_quantize" module.
        self._test_function_import(
            "fake_quantize",
            [
                "_is_per_channel",
                "_is_per_tensor",
                "_is_symmetric_quant",
                "FakeQuantizeBase",
                "FakeQuantize",
                "FixedQParamsFakeQuantize",
                "FusedMovingAvgObsFakeQuantize",
                "default_fake_quant",
                "default_weight_fake_quant",
                "default_fixed_qparams_range_neg1to1_fake_quant",
                "default_fixed_qparams_range_0to1_fake_quant",
                "default_per_channel_weight_fake_quant",
                "default_histogram_fake_quant",
                "default_fused_act_fake_quant",
                "default_fused_wt_fake_quant",
                "default_fused_per_channel_wt_fake_quant",
                "_is_fake_quant_script_module",
                "disable_fake_quant",
                "enable_fake_quant",
                "disable_observer",
                "enable_observer",
            ],
        )

    def test_function_import_fuse_modules(self):
        # Names checked for the "fuse_modules" module.
        self._test_function_import(
            "fuse_modules",
            [
                "_fuse_modules",
                "_get_module",
                "_set_module",
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_known_modules",
                "fuse_modules",
                "get_fuser_method",
            ],
        )

    def test_function_import_quant_type(self):
        # Names checked for the "quant_type" module.
        self._test_function_import(
            "quant_type",
            [
                "QuantType",
                "_get_quant_type_to_str",
            ],
        )

    def test_function_import_observer(self):
        # Names checked for the "observer" module.
        self._test_function_import(
            "observer",
            [
                "_PartialWrapper",
                "_with_args",
                "_with_callable_args",
                "ABC",
                "ObserverBase",
                "_ObserverBase",
                "MinMaxObserver",
                "MovingAverageMinMaxObserver",
                "PerChannelMinMaxObserver",
                "MovingAveragePerChannelMinMaxObserver",
                "HistogramObserver",
                "PlaceholderObserver",
                "RecordingObserver",
                "NoopObserver",
                "_is_activation_post_process",
                "_is_per_channel_script_obs_instance",
                "get_observer_state_dict",
                "load_observer_state_dict",
                "default_observer",
                "default_placeholder_observer",
                "default_debug_observer",
                "default_weight_observer",
                "default_histogram_observer",
                "default_per_channel_weight_observer",
                "default_dynamic_quant_observer",
                "default_float_qparams_observer",
            ],
        )

    def test_function_import_qconfig(self):
        # Names checked for the "qconfig" module.
        self._test_function_import(
            "qconfig",
            [
                "QConfig",
                "default_qconfig",
                "default_debug_qconfig",
                "default_per_channel_qconfig",
                "QConfigDynamic",
                "default_dynamic_qconfig",
                "float16_dynamic_qconfig",
                "float16_static_qconfig",
                "per_channel_dynamic_qconfig",
                "float_qparams_weight_only_qconfig",
                "default_qat_qconfig",
                "default_weight_only_qconfig",
                "default_activation_only_qconfig",
                "default_qat_qconfig_v2",
                "get_default_qconfig",
                "get_default_qat_qconfig",
                "_assert_valid_qconfig",
                "QConfigAny",
                "_add_module_to_qconfig_obs_ctr",
                "qconfig_equals",
            ],
        )

    def test_function_import_quantization_mappings(self):
        # Function names first, then dictionary names, both for the
        # "quantization_mappings" module.
        self._test_function_import(
            "quantization_mappings",
            [
                "no_observer_set",
                "get_default_static_quant_module_mappings",
                "get_static_quant_module_class",
                "get_dynamic_quant_module_class",
                "get_default_qat_module_mappings",
                "get_default_dynamic_quant_module_mappings",
                "get_default_qconfig_propagation_list",
                "get_default_compare_output_module_list",
                "get_default_float_to_quantized_operator_mappings",
                "get_quantized_operator",
                "_get_special_act_post_process",
                "_has_special_act_post_process",
            ],
        )
        self._test_dict_import(
            "quantization_mappings",
            [
                "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_QAT_MODULE_MAPPINGS",
                "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
                # "_INCLUDE_QCONFIG_PROPAGATE_LIST" is disabled (commented
                # out) here. TODO(review): confirm whether it should be
                # checked or the entry dropped.
                "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
                "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
            ],
        )

    def test_function_import_fuser_method_mappings(self):
        # Function names first, then the single dictionary name, both for
        # the "fuser_method_mappings" module.
        self._test_function_import(
            "fuser_method_mappings",
            [
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_linear_bn",
                "get_fuser_method",
            ],
        )
        self._test_dict_import(
            "fuser_method_mappings",
            ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"],
        )

    def test_function_import_utils(self):
        # Names checked for the "utils" module.
        self._test_function_import(
            "utils",
            [
                "activation_dtype",
                "activation_is_int8_quantized",
                "activation_is_statically_quantized",
                "calculate_qmin_qmax",
                "check_min_max_valid",
                "get_combined_dict",
                "get_qconfig_dtypes",
                "get_qparam_dict",
                "get_quant_type",
                "get_swapped_custom_module_class",
                "getattr_from_fqn",
                "is_per_channel",
                "is_per_tensor",
                "weight_dtype",
                "weight_is_quantized",
                "weight_is_statically_quantized",
            ],
        )