| # Owner(s): ["oncall: quantization"] |
| from typing import List, Tuple |
| |
| import torch |
| from torch._export import ( |
| capture_pre_autograd_graph, |
| ) |
| from torch import Tensor |
| from torch.ao.quantization import ( |
| observer, |
| ObserverOrFakeQuantize, |
| QConfigMapping, |
| ) |
| from torch.ao.quantization.quantizer import ( |
| DerivedQuantizationSpec, |
| FixedQParamsQuantizationSpec, |
| QuantizationAnnotation, |
| QuantizationSpec, |
| Quantizer, |
| SharedQuantizationSpec, |
| ) |
| from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| XNNPACKQuantizer, |
| get_symmetric_quantization_config, |
| ) |
| from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( |
| OP_TO_ANNOTATOR, |
| QuantizationConfig, |
| ) |
| from torch.ao.quantization.quantizer.composable_quantizer import ( # noqa: F811 |
| ComposableQuantizer, |
| ) |
| from torch.ao.quantization.quantizer.embedding_quantizer import ( # noqa: F811 |
| EmbeddingQuantizer, |
| ) |
| |
| from torch.ao.quantization.quantize_pt2e import ( |
| convert_pt2e, |
| prepare_pt2e, |
| prepare_qat_pt2e, |
| ) |
| |
| from torch.ao.quantization.qconfig import ( |
| default_per_channel_symmetric_qnnpack_qconfig, |
| float_qparams_weight_only_qconfig, |
| per_channel_weight_observer_range_neg_127_to_127, |
| QConfig, |
| weight_observer_range_neg_127_to_127, |
| ) |
| from torch.fx import Node |
| |
| from torch.testing._internal.common_quantization import ( |
| NodeSpec as ns, |
| PT2EQuantizationTestCase, |
| skipIfNoQNNPACK, |
| TestHelperModules, |
| ) |
| from torch.testing._internal.common_utils import ( |
| TemporaryFileName, |
| ) |
| |
| |
| @skipIfNoQNNPACK |
| class TestQuantizePT2E(PT2EQuantizationTestCase): |
| def test_simple_quantizer(self): |
| # TODO: use OP_TO_ANNOTATOR |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| bias_qspec = QuantizationSpec( |
| dtype=torch.float32, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.PlaceholderObserver, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| example_inputs = (torch.randn(1, 3, 5, 5),) |
| node_occurrence = { |
| # quantize for the conv input and output (the weight quantize is folded); dequantize for input, weight, and output |
| torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, |
| } |
| node_list = [ |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default, |
| torch.ops.aten.conv2d.default, |
| torch.ops.quantized_decomposed.quantize_per_tensor.default, |
| ] |
| self._test_quantizer( |
| TestHelperModules.ConvWithBNRelu(relu=False, bn=False), |
| example_inputs, |
| BackendAQuantizer(), |
| node_occurrence, |
| node_list, |
| ) |
| |
| def test_wo_annotate_conv_output_quantizer(self): |
| # TODO: use OP_TO_ANNOTATOR |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| bias_qspec = QuantizationSpec( |
| dtype=torch.float32, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.PlaceholderObserver, |
| ) |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| m = torch.nn.Conv2d(2, 2, 1) |
| x = torch.rand(1, 2, 14, 14) |
| example_inputs = (x,) |
| m = self._quantize(m, BackendAQuantizer(), example_inputs) |
| # Ensure the conv has no observer inserted at output |
| node_occurrence = { |
| # one quantize for the conv input; dequantize for the input and the (folded) weight |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 1, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, |
| } |
| node_list = [ |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default), |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default), |
| ns.call_function(torch.ops.aten.conv2d.default), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_list=node_list, expected_node_occurrence=node_occurrence |
| ) |
| |
| def test_max_pool2d_quantizer(self): |
| # TODO: use OP_TO_ANNOTATOR |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| bias_qspec = QuantizationSpec( |
| dtype=torch.float32, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.PlaceholderObserver, |
| ) |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| _annotated=True, |
| ) |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.max_pool2d.default |
| ): |
| maxpool_node = node |
| input_act = maxpool_node.args[0] |
| assert isinstance(input_act, Node) |
| maxpool_node.meta[ |
| "quantization_annotation" |
| ] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| }, |
| output_qspec=SharedQuantizationSpec( |
| (input_act, maxpool_node) |
| ), |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| m = TestHelperModules.ConvMaxPool2d() |
| x = torch.rand(1, 2, 14, 14) |
| example_inputs = (x,) |
| m = self._quantize(m, BackendAQuantizer(), example_inputs) |
| node_occurrence = { |
| # two for input of conv |
| # one for input of maxpool |
| # one for output of maxpool |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 4, |
| } |
| node_list = [ |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default), |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default), |
| ns.call_function(torch.ops.aten.conv2d.default), |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default), |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default), |
| ns.call_function(torch.ops.aten.max_pool2d.default), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_list=node_list, expected_node_occurrence=node_occurrence |
| ) |
| |
| def test_derived_qspec(self): |
| # TODO: use OP_TO_ANNOTATOR |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| |
| def derive_qparams_fn( |
| obs_or_fqs: List[ObserverOrFakeQuantize], |
| ) -> Tuple[Tensor, Tensor]: |
| assert ( |
| len(obs_or_fqs) == 2 |
| ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" |
| act_obs_or_fq = obs_or_fqs[0] |
| weight_obs_or_fq = obs_or_fqs[1] |
| act_scale, act_zp = act_obs_or_fq.calculate_qparams() |
| ( |
| weight_scale, |
| weight_zp, |
| ) = weight_obs_or_fq.calculate_qparams() |
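| # bias scale is derived as act_scale * weight_scale with a zero_point of 0, as is typical for int32 bias quantization |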
| return torch.tensor([act_scale * weight_scale]).to( |
| torch.float32 |
| ), torch.tensor([0]).to(torch.int32) |
| |
| bias_qspec = DerivedQuantizationSpec( |
| derived_from=[(input_act, node), (weight, node)], |
| derive_qparams_fn=derive_qparams_fn, |
| dtype=torch.int32, |
| quant_min=-(2**31), |
| quant_max=2**31 - 1, |
| qscheme=torch.per_tensor_symmetric, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() |
| example_inputs = (torch.randn(1, 3, 5, 5),) |
| |
| m = self._quantize(m, BackendAQuantizer(), example_inputs) |
| node_occurrence = { |
| # input, weight, bias, output for the conv |
| # note: quantize ops for weight and bias are const propagated |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ): 2, |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ): 4, |
| } |
| node_list = [ |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function(torch.ops.aten.conv2d.default), |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_list=node_list, expected_node_occurrence=node_occurrence |
| ) |
| |
| def test_derived_qspec_per_channel(self): |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_channel_affine, |
| is_dynamic=False, |
| ch_axis=0, |
| observer_or_fake_quant_ctr=observer.default_per_channel_weight_observer, |
| ) |
| |
| def derive_qparams_fn( |
| obs_or_fqs: List[ObserverOrFakeQuantize], |
| ) -> Tuple[Tensor, Tensor]: |
| assert ( |
| len(obs_or_fqs) == 1 |
| ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}" |
| weight_obs_or_fq = obs_or_fqs[0] |
| ( |
| weight_scale, |
| weight_zp, |
| ) = weight_obs_or_fq.calculate_qparams() |
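| # the derived per-channel bias scales are just the per-channel weight scales, with all zero_points set to zero |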
| return weight_scale, torch.zeros_like(weight_scale) |
| |
| bias_qspec = DerivedQuantizationSpec( |
| derived_from=[(weight, node)], |
| derive_qparams_fn=derive_qparams_fn, |
| dtype=torch.int32, |
| quant_min=-(2**31), |
| quant_max=2**31 - 1, |
| qscheme=torch.per_channel_symmetric, |
| ch_axis=0, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() |
| example_inputs = (torch.randn(1, 3, 5, 5),) |
| |
| m = self._quantize(m, BackendAQuantizer(), example_inputs) |
| |
| node_occurrence = { |
| # input, output for the conv |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ): 2, |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ): 2, |
| # weight and bias for conv |
| # note: quantize ops for weight and bias are const propagated |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_channel.default |
| ): 0, |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_channel.default |
| ): 2, |
| } |
| node_list = [ |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_channel.default |
| ), |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_channel.default |
| ), |
| ns.call_function(torch.ops.aten.conv2d.default), |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_list=node_list, expected_node_occurrence=node_occurrence |
| ) |
| |
| def test_fixed_qparams_qspec(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| return torch.sigmoid(x) |
| |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.sigmoid.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
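| # sigmoid outputs lie in (0, 1), so a fixed scale of 1/256 with zero_point 0 maps them onto the uint8 range |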
| act_qspec = FixedQParamsQuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| scale=1.0 / 256.0, |
| zero_point=0, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| m = M().eval() |
| example_inputs = (torch.randn(1, 3, 5, 5),) |
| |
| m = self._quantize(m, BackendAQuantizer(), example_inputs) |
| fixed_scale = 1.0 / 256.0 |
| fixed_zero_point = 0 |
| for n in m.graph.nodes: |
| if n.op == "call_function": |
| if ( |
| n.target |
| == torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ): |
| scale_0 = n.args[1] |
| zero_point_0 = n.args[2] |
| if ( |
| n.target |
| == torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ): |
| scale_1 = n.args[1] |
| zero_point_1 = n.args[2] |
| self.assertEqual(scale_0, fixed_scale) |
| self.assertEqual(zero_point_0, fixed_zero_point) |
| self.assertEqual(scale_1, fixed_scale) |
| self.assertEqual(zero_point_1, fixed_zero_point) |
| node_occurrence = { |
| # one quantize/dequantize pair for the input and one for the output of sigmoid |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ): 2, |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ): 2, |
| } |
| node_list = [ |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function(torch.ops.aten.sigmoid.default), |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_list=node_list, expected_node_occurrence=node_occurrence |
| ) |
| |
| def test_shared_qspec(self): |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| bias_qspec = QuantizationSpec( |
| dtype=torch.float32, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.PlaceholderObserver, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| elif node.target is torch.ops.aten.cat.default: |
| cat_node = node |
| input_nodes = cat_node.args[0] |
| first_input_node = input_nodes[0] |
| input_qspec_map = {} |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| input_qspec_map[first_input_node] = act_qspec |
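| # the remaining cat inputs and the cat output share quantization parameters with the first input edge (first_input_node, cat_node) |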
| share_qparams_with_input_act0_qspec = SharedQuantizationSpec((first_input_node, cat_node)) |
| for input_node in input_nodes[1:]: |
| input_qspec_map[input_node] = share_qparams_with_input_act0_qspec |
| |
| cat_node.meta[ |
| "quantization_annotation" |
| ] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=share_qparams_with_input_act0_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| |
| m = TestHelperModules.Conv2dWithCat().eval() |
| example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) |
| |
| # program capture |
| m = capture_pre_autograd_graph( |
| m, |
| example_inputs, |
| ) |
| m = prepare_pt2e(m, BackendAQuantizer()) |
| # make sure the two observers for input are shared |
| conv_output_obs = [] |
| for n in m.graph.nodes: |
| if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default: |
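| # the conv node's only user is the observer module inserted at its output |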
| conv_output_obs.append(getattr(m, next(iter(n.users)).target)) |
| if n.op == "call_function" and n.target == torch.ops.aten.cat.default: |
| inputs = n.args[0] |
| input0 = inputs[0] |
| input1 = inputs[1] |
| assert input0.op == "call_module" |
| assert input1.op == "call_module" |
| obs_ins0 = getattr(m, input0.target) |
| obs_ins1 = getattr(m, input1.target) |
| assert obs_ins0 == obs_ins1 |
| assert len(conv_output_obs) == 2, "expecting two observers that follow the conv2d ops" |
| # checking that the output observers for the two convs are shared as well |
| assert conv_output_obs[0] == conv_output_obs[1] |
| |
| m(*example_inputs) |
| m = convert_pt2e(m, fold_quantize=True) |
| |
| node_occurrence = { |
| # quantize for the two conv inputs, the two conv outputs, and the cat output; |
| # dequantize additionally covers the two (folded) conv weights |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ): 5, |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ): 7, |
| } |
| node_list = [ |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function(torch.ops.aten.cat.default), |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_list=node_list, expected_node_occurrence=node_occurrence |
| ) |
| |
| def _test_transitive_sharing_with_cat_helper(self, quantizer): |
| m = TestHelperModules.Conv2dWithTwoCat().eval() |
| example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5), torch.randn(1, 6, 3, 3), torch.randn(1, 6, 3, 3)) |
| |
| # program capture |
| m = capture_pre_autograd_graph( |
| m, |
| example_inputs, |
| ) |
| m = prepare_pt2e(m, quantizer) |
| m(*example_inputs) |
| # make sure the two input observers and output are shared |
| conv_output_obs = [] |
| for n in m.graph.nodes: |
| if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default: |
| conv_output_obs.append(getattr(m, next(iter(n.users)).target)) |
| if n.op == "call_function" and n.target == torch.ops.aten.cat.default: |
| inputs = n.args[0] |
| input0 = inputs[0] |
| input1 = inputs[1] |
| assert input0.op == "call_module" |
| assert input1.op == "call_module" |
| obs_ins0 = getattr(m, input0.target) |
| obs_ins1 = getattr(m, input1.target) |
| assert obs_ins0 == obs_ins1 |
| |
| output_obs = next(iter(n.users)) |
| assert output_obs.op == "call_module" |
| obs_ins2 = getattr(m, output_obs.target) |
| assert obs_ins0 == obs_ins2, "input observer does not match output" |
| |
| assert len(conv_output_obs) == 2, "expecting two observers that follow the conv2d ops" |
| # checking that the output observers for the two convs are shared as well |
| assert conv_output_obs[0] == conv_output_obs[1] |
| |
| m(*example_inputs) |
| m = convert_pt2e(m, fold_quantize=True) |
| |
| node_occurrence = { |
| # quantize/dequantize pairs are inserted around each activation tensor of the convs and cats; |
| # the two extra dequantize ops are for the (folded) conv weights |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ): 7, |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ): 9, |
| } |
| node_list = [ |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function(torch.ops.aten.cat.default), |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ), |
| ns.call_function( |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| ), |
| ns.call_function(torch.ops.aten.cat.default), |
| ns.call_function( |
| torch.ops.quantized_decomposed.quantize_per_tensor.default |
| ), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_list=node_list, expected_node_occurrence=node_occurrence |
| ) |
| |
| def test_shared_qspec_transitivity(self): |
| """This tests the transitivity of SharedQuantizationSpec, that is |
| if A is shared with B and B is shared with C, then C should be shared with A as well |
| |
| x1 -> conv1 -> cat1 -----> cat2 |
| x2 -> conv2 -/ / |
| x3 -> add / |
| x4 / |
| |
| both cats have shared inputs and outputs, and because the output of cat1 is also an input of cat2 |
| they refer to the same Tensor, so there is implicit sharing here: all tensors connected to cat1 |
| and cat2 end up in the same sharing group after transitive sharing |
| """ |
| # TODO: refactor this to a common util |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| bias_qspec = QuantizationSpec( |
| dtype=torch.float32, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.PlaceholderObserver, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| elif node.target is torch.ops.aten.cat.default: |
| cat_node = node |
| input_nodes = cat_node.args[0] |
| first_input_node = input_nodes[0] |
| input_qspec_map = {} |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| input_qspec_map[first_input_node] = act_qspec |
| share_qparams_with_input_act0_qspec = SharedQuantizationSpec((first_input_node, cat_node)) |
| for input_node in input_nodes[1:]: |
| input_qspec_map[input_node] = share_qparams_with_input_act0_qspec |
| |
| cat_node.meta[ |
| "quantization_annotation" |
| ] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=share_qparams_with_input_act0_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| self._test_transitive_sharing_with_cat_helper(BackendAQuantizer()) |
| |
| def test_shared_qspec_transitivity_case_2(self): |
| """This tests the transitivity of SharedQuantizationSpec, that is |
| if A is shared with B and B is shared with C, then C should be shared with A as well |
| |
| x1 -> conv1 -> cat1 -----> cat2 |
| x2 -> conv2 -/ / |
| x3 -> add / |
| x4 / |
| |
| both cats have shared inputs and outputs, and because the output of cat1 is also an input of cat2 |
| they refer to the same Tensor, so there is implicit sharing here: all tensors connected to cat1 |
| and cat2 end up in the same sharing group after transitive sharing |
|
| the difference is that here all edges and nodes are shared with the second input edge of cat |
| instead of the first input edge of cat, as in the previous example |
| """ |
| # TODO: refactor this to a common util |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.conv2d.default |
| ): |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| bias = node.args[2] |
| assert isinstance(bias, Node) |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| weight_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| bias_qspec = QuantizationSpec( |
| dtype=torch.float32, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.PlaceholderObserver, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| weight: weight_qspec, |
| bias: bias_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| elif node.target is torch.ops.aten.cat.default: |
| cat_node = node |
| input_nodes = cat_node.args[0] |
| first_input_node = input_nodes[0] |
| second_input_node = input_nodes[1] |
| input_qspec_map = {} |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| input_qspec_map[second_input_node] = act_qspec |
| share_qparams_with_input_act1_qspec = SharedQuantizationSpec((second_input_node, cat_node)) |
| input_qspec_map[first_input_node] = share_qparams_with_input_act1_qspec |
| |
| cat_node.meta[ |
| "quantization_annotation" |
| ] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=share_qparams_with_input_act1_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| self._test_transitive_sharing_with_cat_helper(BackendAQuantizer()) |
| |
| def test_allow_implicit_sharing(self): |
| """This tests the allow_implicit_sharing flag of QuantizationAnnotation, that is, |
| if a node is configured with allow_implicit_sharing=False, we will not have implicit sharing |
| between node and (node, consumer) even when they refer to the same Tensor |
| |
| x1 -> add1 -----> add3 |
| x2 -/ / |
| x3 -> add2 / |
| x4 -/ |
| |
| every add has shared inputs and output, and the second input uses a shared quantization spec pointing |
| to the first input, but we set allow_implicit_sharing to False for all add nodes, so the inputs and |
| output of add1, add2 and add3 each belong to their own sharing group, giving: |
| |
| x1 -> obs1 -> add1 -> obs1 -> obs3--> add3 -> obs3 |
| x2 -> obs1 -/ / |
| x3 -> obs2 -> add2 -> obs2 -> obs3 |
| x4 -> obs2 -/ |
| """ |
| # TODO: refactor this to a common util |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if node.target is torch.ops.aten.add.Tensor: |
| add_node = node |
| first_input_node = add_node.args[0] |
| second_input_node = add_node.args[1] |
| input_qspec_map = {} |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| input_qspec_map[second_input_node] = act_qspec |
| share_qparams_with_input_act1_qspec = SharedQuantizationSpec((second_input_node, add_node)) |
| input_qspec_map[first_input_node] = share_qparams_with_input_act1_qspec |
| |
| add_node.meta[ |
| "quantization_annotation" |
| ] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=share_qparams_with_input_act1_qspec, |
| allow_implicit_sharing=False, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| m = TestHelperModules.ThreeAdd().eval() |
| example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) |
| |
| # program capture |
| m = capture_pre_autograd_graph( |
| m, |
| example_inputs, |
| ) |
| quantizer = BackendAQuantizer() |
| m = prepare_pt2e(m, quantizer) |
| m(*example_inputs) |
| observers = [] |
| for n in m.graph.nodes: |
| if n.target == torch.ops.aten.add.Tensor: |
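| # each add's inputs and output should map to one shared observer, while different adds should not share observers |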
| input_obs1 = getattr(m, n.args[0].target) |
| input_obs2 = getattr(m, n.args[1].target) |
| output_obs = getattr(m, next(iter(n.users)).target) |
| self.assertIs(input_obs1, input_obs2) |
| self.assertIs(input_obs1, output_obs) |
| observers.append(input_obs1) |
| assert len(observers) == 3 |
| self.assertIsNot(observers[0], observers[1]) |
| self.assertIsNot(observers[0], observers[2]) |
| self.assertIsNot(observers[1], observers[2]) |
| |
| def test_int16(self): |
| class Int16ActQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| # int16 quantization spec for activations |
| int16_qspec = QuantizationSpec( |
| dtype=torch.int16, |
| quant_min=-2**15, |
| quant_max=2**15 - 1, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_observer, |
| ) |
| int8_qspec = QuantizationSpec( |
| dtype=torch.int8, |
| quant_min=-128, |
| quant_max=127, |
| qscheme=torch.per_tensor_symmetric, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer.default_weight_observer, |
| ) |
| quantization_config = QuantizationConfig( |
| input_activation=int16_qspec, |
| weight=int8_qspec, |
| bias=None, |
| output_activation=int16_qspec, |
| ) |
| OP_TO_ANNOTATOR["conv"](model, quantization_config) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| quantizer = Int16ActQuantizer() |
| node_occurrence = { |
| # one for input of the first conv, one for output for the first conv |
| torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, |
| } |
| node_list = [ |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default, |
| torch.ops.aten.conv2d.default, |
| torch.ops.quantized_decomposed.quantize_per_tensor.default, |
| ] |
| example_inputs = (torch.randn(1, 3, 3, 3),) |
| self._test_quantizer( |
| M().eval(), |
| example_inputs, |
| Int16ActQuantizer(), |
| node_occurrence, |
| node_list, |
| ) |
| |
| def test_fold_quantize(self): |
| """Test to make sure the quantized model gets a quantized weight (the quantize_per_tensor op is folded) |
| """ |
| m = self._get_pt2e_quantized_linear() |
| node_occurrence = { |
| # quantize op for weight node is folded |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_fold_quantize_per_channel(self): |
| """Test to make sure the quantized model gets a quantized weight (the quantize_per_channel op is folded) |
| """ |
| m = self._get_pt2e_quantized_linear(is_per_channel=True) |
| node_occurrence = { |
| # quantize op for weight node is folded |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_dont_fold_other_constant(self): |
| """Make sure the constant propagation does not apply to things unrelated to |
| quantization |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(2, 2) |
| self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2)) |
| |
| def forward(self, x): |
| t = self.dont_fold_me.t() |
| return self.linear(x) + t |
| |
| quantizer = XNNPACKQuantizer() |
| operator_config = get_symmetric_quantization_config(is_per_channel=False) |
| # only quantize linear, so add is not quantized and the constant Tensor |
| # should not be folded |
| quantizer.set_module_type(torch.nn.Linear, operator_config) |
| example_inputs = (torch.randn(2, 2),) |
| m = M().eval() |
| m = self._quantize(m, quantizer, example_inputs) |
| node_occurrence = { |
| # quantize op for weight node is folded |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3, |
| # transpose op not folded |
| ns.call_function(torch.ops.aten.t.default): 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_fold_all_ops_before_quantize(self): |
| """Test folding all ops that come before the quantized operator: |
| Before: |
| get_attr(weight) -> transpose -> quantize -> dequantize |
| After: |
| get_attr(folded_weight) -> dequantize |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.randn(2, 2) |
| |
| def forward(self, x): |
| t = self.weight.t() |
| return torch.nn.functional.linear(x, t) |
| |
| quantizer = XNNPACKQuantizer() |
| operator_config = get_symmetric_quantization_config(is_per_channel=False) |
| quantizer.set_global(operator_config) |
| example_inputs = (torch.randn(2, 2),) |
| m = M().eval() |
| m = self._quantize(m, quantizer, example_inputs) |
| node_occurrence = { |
| # quantize op for weight node is folded |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_constant_prop_preserve_metadata(self): |
| """Test to make sure the get_attr node for the const-propagated weight Tensor gets the correct |
| metadata (copied from the original get_attr node for the weight) |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(2, 2) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| quantizer = XNNPACKQuantizer() |
| operator_config = get_symmetric_quantization_config() |
| quantizer.set_global(operator_config) |
| example_inputs = (torch.randn(2, 2),) |
| m = M().eval() |
| m = capture_pre_autograd_graph( |
| m, |
| example_inputs, |
| ) |
| weight_meta = None |
| for n in m.graph.nodes: |
| if n.op == "get_attr" and next(iter(n.users)).target == torch.ops.aten.linear.default: |
| weight_meta = n.meta |
| break |
| assert weight_meta is not None, "Expect to find metadata for weight node" |
| |
| m = prepare_pt2e(m, quantizer) |
| m(*example_inputs) |
| m = convert_pt2e(m, fold_quantize=True) |
| |
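| # const-propagated weights show up as get_attr nodes with "frozen_param" in their name |
| # and should carry over the original weight node's metadata |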
| for n in m.graph.nodes: |
| if n.op == "get_attr" and "frozen_param" in n.target: |
| self.assertIn("stack_trace", n.meta) |
| for key in n.meta: |
| self.assertEqual(n.meta[key], weight_meta[key]) |
| |
| def test_save_load(self): |
| """Test saving and loading a quantized model |
| """ |
| m = self._get_pt2e_quantized_linear() |
| example_inputs = (torch.randn(2, 2),) |
| ref_res = m(*example_inputs) |
| |
| with TemporaryFileName() as fname: |
| # serialization |
| quantized_ep = torch.export.export(m, example_inputs) |
| torch.export.save(quantized_ep, fname) |
| # deserialization |
| loaded_ep = torch.export.load(fname) |
| loaded_quantized_model = loaded_ep.module() |
| res = loaded_quantized_model(*example_inputs) |
| self.assertEqual(ref_res, res) |
| |
| def test_composable_quantizer_throw(self): |
| class BadQuantizer(Quantizer): |
| def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for n in gm.graph.nodes: |
| n.meta["quantization_annotation"] = None |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| quantizer = XNNPACKQuantizer() |
| quantization_config = get_symmetric_quantization_config(is_per_channel=True) |
| quantizer.set_global(quantization_config) |
| bad_quantizer = BadQuantizer() |
| composable_quantizer = ComposableQuantizer([quantizer, bad_quantizer]) |
| m_eager = TestHelperModules.ConvLinearWPermute().eval() |
| example_inputs = (torch.randn(2, 3, 4, 4),) |
| self.assertRaises( |
| RuntimeError, |
| lambda: self._test_quantizer( |
| m_eager, example_inputs, composable_quantizer, {} |
| ), |
| ) |
| |
| def test_transform_for_annotation(self): |
| class TestQuantizer(Quantizer): |
| def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for n in model.graph.nodes: |
| if n.target == torch.ops.aten.add.Tensor: |
| n.target = torch.ops.aten.mul.Tensor |
| return model |
| |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| return model |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| class M(torch.nn.Module): |
| def forward(self, x): |
| return x + 3 |
| |
| m = M().eval() |
| quantizer = TestQuantizer() |
| example_inputs = (torch.randn(1, 2, 3, 3),) |
| m = capture_pre_autograd_graph(m, example_inputs) |
| m = prepare_pt2e(m, quantizer) |
| m(*example_inputs) |
| node_occurrence = { |
| ns.call_function(torch.ops.aten.add.Tensor): 0, |
| ns.call_function(torch.ops.aten.mul.Tensor): 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_embedding_quantizer(self): |
| m_eager = TestHelperModules.EmbeddingModule().eval() |
| indices = torch.tensor( |
| [ |
| 9, |
| 6, |
| 5, |
| 7, |
| 8, |
| 8, |
| 9, |
| 2, |
| 8, |
| 6, |
| 6, |
| 9, |
| 1, |
| 6, |
| 8, |
| 8, |
| 3, |
| 2, |
| 3, |
| 6, |
| 3, |
| 6, |
| 5, |
| 7, |
| 0, |
| 8, |
| 4, |
| 6, |
| 5, |
| 8, |
| 2, |
| 3, |
| ] |
| ) |
| example_inputs = (indices,) |
| |
| quantizer = EmbeddingQuantizer() |
| node_occurrence = { |
| # note: quantize ops for weights are const propagated |
| torch.ops.quantized_decomposed.quantize_per_channel.default: 0, |
| torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, |
| } |
| node_list = [ |
| torch.ops.quantized_decomposed.dequantize_per_channel.default, |
| torch.ops.aten.embedding.default, |
| ] |
| # Compare against short term workflow |
| # cannot compare against fx quant because of the numerical differences coming |
| # from quantize and dequantize ops |
| qconfig = default_per_channel_symmetric_qnnpack_qconfig |
| qconfig_mapping = QConfigMapping().set_global(qconfig) |
| qconfig_mapping = qconfig_mapping.set_object_type( |
| torch.nn.Embedding, float_qparams_weight_only_qconfig |
| ) |
| self._test_quantizer( |
| m_eager, |
| example_inputs, |
| quantizer, |
| node_occurrence, |
| node_list, |
| True, |
| qconfig_mapping, |
| ) |
| |
| def test_composable_quantizer_linear_conv(self): |
| dynamic_quantizer = XNNPACKQuantizer() |
| quantization_config_dynamic = get_symmetric_quantization_config( |
| is_per_channel=False, is_dynamic=True |
| ) |
| dynamic_quantizer.set_global(quantization_config_dynamic) |
| static_quantizer = XNNPACKQuantizer() |
| quantization_config = get_symmetric_quantization_config(is_per_channel=True) |
| static_quantizer.set_global(quantization_config) |
| # Note that dynamic quantization must be applied first here: |
| # the static quantizer also quantizes linear with a static qspec, |
| # and if static_quantizer were applied first, dynamic_quantizer could not be applied on top |
| composable_quantizer = ComposableQuantizer( |
| [dynamic_quantizer, static_quantizer] |
| ) |
| m_eager = TestHelperModules.ConvLinearWPermute().eval() |
| |
| node_occurrence = { |
| torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, |
| # note: quantize ops for weights are const propagated |
| torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, |
| # note: quantize ops for weights are const propagated |
| torch.ops.quantized_decomposed.quantize_per_channel.default: 0, |
| torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, |
| } |
| act_affine_quant_obs = observer.PlaceholderObserver.with_args( |
| dtype=torch.qint8, |
| qscheme=torch.per_tensor_affine, |
| quant_min=-128, |
| quant_max=127, |
| eps=2**-12, |
| is_dynamic=True, |
| ) |
| dynamic_qconfig = QConfig( |
| activation=act_affine_quant_obs, |
| weight=weight_observer_range_neg_127_to_127, |
| ) |
| # Test with 2d inputs |
| example_inputs = (torch.randn(2, 3, 4, 4),) |
| qconfig = default_per_channel_symmetric_qnnpack_qconfig |
| qconfig_mapping = QConfigMapping().set_global(qconfig) |
| qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) |
| # Had to turn off the check against fx because the fx quant workflow does not seem |
| # to propagate observers for the permute node in this model. |
| # Surprisingly it does propagate them for EmbeddingConvLinearModule |
| # TODO: Figure out the right behavior for propagation |
| self._test_quantizer( |
| m_eager, |
| example_inputs, |
| composable_quantizer, |
| node_occurrence, |
| [], |
| False, |
| qconfig_mapping, |
| ) |
| |
| def test_embedding_conv_linear_quantization(self): |
| m_eager = TestHelperModules.EmbeddingConvLinearModule().eval() |
| indices = torch.tensor( |
| [ |
| 9, |
| 6, |
| 5, |
| 7, |
| 8, |
| 8, |
| 9, |
| 2, |
| 8, |
| 6, |
| 6, |
| 9, |
| 1, |
| 6, |
| 8, |
| 8, |
| 3, |
| 2, |
| 3, |
| 6, |
| 3, |
| 6, |
| 5, |
| 7, |
| 0, |
| 8, |
| 4, |
| 6, |
| 5, |
| 8, |
| 2, |
| 3, |
| ] |
| ) |
| indices = torch.unsqueeze(indices, 0) |
| example_inputs = (indices,) |
| |
| embedding_quantizer = EmbeddingQuantizer() |
| dynamic_quantizer = XNNPACKQuantizer() |
| quantization_config_dynamic = get_symmetric_quantization_config( |
| is_per_channel=True, is_dynamic=True |
| ) |
| dynamic_quantizer.set_global(quantization_config_dynamic) |
| static_quantizer = XNNPACKQuantizer() |
| quantization_config = get_symmetric_quantization_config(is_per_channel=True) |
| static_quantizer.set_global(quantization_config) |
| composed_quantizer = ComposableQuantizer( |
| [embedding_quantizer, dynamic_quantizer, static_quantizer] |
| ) |
| |
| act_affine_quant_obs = observer.PlaceholderObserver.with_args( |
| dtype=torch.qint8, |
| qscheme=torch.per_tensor_affine, |
| quant_min=-128, |
| quant_max=127, |
| eps=2**-12, |
| is_dynamic=True, |
| ) |
| dynamic_qconfig = QConfig( |
| activation=act_affine_quant_obs, |
| weight=per_channel_weight_observer_range_neg_127_to_127, |
| ) |
| qconfig = default_per_channel_symmetric_qnnpack_qconfig |
| qconfig_mapping = QConfigMapping().set_global(qconfig) |
| qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) |
| qconfig_mapping = qconfig_mapping.set_object_type( |
| torch.nn.Embedding, float_qparams_weight_only_qconfig |
| ) |
| |
| node_occurrence = { |
| torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, |
| torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, |
| torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, |
| # note: quantize ops for weights are const propagated |
| torch.ops.quantized_decomposed.quantize_per_channel.default: 0, |
| torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, |
| } |
| self._test_quantizer( |
| m_eager, |
| example_inputs, |
| composed_quantizer, |
| node_occurrence, |
| [], |
| True, |
| qconfig_mapping, |
| ) |
| |
| def _get_node(self, m: torch.fx.GraphModule, target: torch._ops.OpOverload): |
| """ |
| Return the first node matching the specified target, raising an exception |
| if no such node is found. |
| """ |
| for n in m.graph.nodes: |
| if n.target == target: |
| return n |
| raise ValueError(f"Did not find node with target {target}") |
| |
| def _test_move_exported_model_dropout(self, inplace: bool): |
| """ |
| Test switching dropout behavior between train and eval modes using |
| `move_exported_model_to_eval` and `move_exported_model_to_train` APIs. |
| """ |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.dropout = torch.nn.Dropout(0.5, inplace=inplace) |
| |
| def forward(self, x): |
| return self.dropout(x) |
| |
| example_inputs = (torch.randn(1),) |
| m = M().train() |
| m = capture_pre_autograd_graph(m, example_inputs) |
| if inplace: |
| target = torch.ops.aten.dropout_.default |
| else: |
| target = torch.ops.aten.dropout.default |
| |
| # Assert that dropout op exists and is in train mode |
| dropout_node = self._get_node(m, target) |
| self.assertTrue(dropout_node is not None) |
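| # args[2] is the `train` flag of aten.dropout(input, p, train) |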
| self.assertTrue(dropout_node.args[2]) |
| |
| # Move to eval |
| torch.ao.quantization.move_exported_model_to_eval(m) |
| |
| # Assert that dropout op is now in eval mode |
| dropout_node = self._get_node(m, target) |
| self.assertTrue(dropout_node is not None) |
| self.assertTrue(not dropout_node.args[2]) |
| |
| # Move back to train |
| torch.ao.quantization.move_exported_model_to_train(m) |
| |
| # Assert that dropout op is now in train mode again |
| dropout_node = self._get_node(m, target) |
| self.assertTrue(dropout_node is not None) |
| self.assertTrue(dropout_node.args[2]) |
| |
| def test_move_exported_model_dropout(self): |
| self._test_move_exported_model_dropout(inplace=False) |
| |
| def test_move_exported_model_dropout_inplace(self): |
| self._test_move_exported_model_dropout(inplace=True) |
| |
| def test_move_exported_model_bn(self): |
| """ |
| Test switching batch_norm behavior between train and eval modes using |
| `move_exported_model_to_eval` and `move_exported_model_to_train` APIs. |
| """ |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bn = torch.nn.BatchNorm2d(3) |
| |
| def forward(self, x): |
| return self.bn(x) |
| |
| example_inputs = (torch.randn(1, 3, 3, 3),) |
| m = M().train() |
| m = capture_pre_autograd_graph(m, example_inputs) |
| |
| # Assert that batch norm op exists and is in train mode |
| bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default) |
| self.assertTrue(bn_node is not None) |
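| # args[5] is the `training` flag of _native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps) |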
| self.assertTrue(bn_node.args[5]) |
| |
| # Move to eval |
| torch.ao.quantization.move_exported_model_to_eval(m) |
| |
| # Assert that batch norm op is now in eval mode |
| bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit_no_training.default) |
| self.assertTrue(bn_node is not None) |
| |
| # Move to train |
| torch.ao.quantization.move_exported_model_to_train(m) |
| |
| # Assert that batch norm op is now in train mode again |
| bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default) |
| self.assertTrue(bn_node is not None) |
| self.assertTrue(bn_node.args[5]) |
| |
| def test_disallow_eval_train(self): |
| m = TestHelperModules.ConvWithBNRelu(relu=True) |
| example_inputs = (torch.rand(3, 3, 5, 5),) |
| |
| # Before export: this is OK |
| m.eval() |
| m.train() |
| |
| # After export: this is not OK |
| m = capture_pre_autograd_graph(m, example_inputs) |
| with self.assertRaises(NotImplementedError): |
| m.eval() |
| with self.assertRaises(NotImplementedError): |
| m.train() |
| |
| # After prepare: still not OK |
| quantizer = XNNPACKQuantizer() |
| m = prepare_qat_pt2e(m, quantizer) |
| with self.assertRaises(NotImplementedError): |
| m.eval() |
| with self.assertRaises(NotImplementedError): |
| m.train() |
| |
| # After convert: still not OK |
| m = convert_pt2e(m, fold_quantize=True) |
| with self.assertRaises(NotImplementedError): |
| m.eval() |
| with self.assertRaises(NotImplementedError): |
| m.train() |
| |
| def test_reentrant(self): |
| """Test that we can safely call quantization APIs multiple times""" |
| m = TestHelperModules.ConvBnReLU2dAndLinearReLU() |
| example_inputs = (torch.randn(3, 3, 10, 10),) |
| |
| quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_per_channel=True, is_qat=True)) |
| m.conv_bn_relu = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs) |
| m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) |
| m(*example_inputs) |
| m.conv_bn_relu = convert_pt2e(m.conv_bn_relu, fold_quantize=True) |
| |
| quantizer = XNNPACKQuantizer().set_module_type(torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)) |
| m = capture_pre_autograd_graph(m, example_inputs) |
| m = prepare_pt2e(m, quantizer) |
| m = convert_pt2e(m, fold_quantize=True) |
| |
| node_occurrence = { |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4, |
| # one for weight |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 5, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1, |
| } |
| node_list = [ |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default), |
| ns.call_function(torch.ops.aten.conv2d.default), |
| ns.call_function(torch.ops.aten.relu.default), |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default), |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default), |
| ns.call_function(torch.ops.aten.linear.default), |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default), |
| ] |
| self.checkGraphModuleNodes( |
| m, expected_node_occurrence=node_occurrence, expected_node_list=node_list |
| ) |
| |
| def test_groupwise_per_channel_quant(self): |
| m = TestHelperModules.GroupwiseConv2d() |
| quantizer = XNNPACKQuantizer() |
| operator_config = get_symmetric_quantization_config(is_per_channel=True) |
| quantizer.set_global(operator_config) |
| example_inputs = m.example_inputs() |
| m = self._quantize(m, quantizer, example_inputs) |
| # make sure it runs |
| m(*example_inputs) |
| |
| def test_observer_callback(self): |
| from torch.library import Library, impl |
| test_lib = Library("test_int4", "DEF") |
| test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") |
| |
| @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") |
| def quantize_per_tensor_int4( |
| input: torch.Tensor, |
| scale: float, |
| zero_point: int, |
| ) -> torch.Tensor: |
| inv_scale = 1.0 / scale |
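| # round/clamp to the 4-bit range [0, 15], store in uint8, then reinterpret the storage as bits8 |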
| return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8) |
| |
| test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") |
| |
| @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") |
| def dequantize_per_tensor_int4( |
| input: torch.Tensor, |
| scale: float, |
| zero_point: int, |
| ) -> torch.Tensor: |
| return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale |
| |
| from torch.ao.quantization.observer import ObserverBase |
| |
| class Int4Observer(ObserverBase): |
| def __init__(self, *args, **kwargs): |
| # just faking a dtype here |
| super().__init__(dtype=torch.int8) |
| |
| def forward(self, x): |
| return x |
| |
| def calculate_qparams(self, **kwargs): |
| pass |
| |
| def convert(self, model: torch.fx.GraphModule, observer_node: Node): |
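| # replace the observer node with the custom int4 quantize/dequantize op pair |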
| with model.graph.inserting_before(observer_node): |
| q_node = model.graph.call_function( |
| torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {}) |
| dq_node = model.graph.call_function( |
| torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {}) |
| observer_node.replace_all_uses_with(dq_node) |
| model.graph.erase_node(observer_node) |
| |
| class BackendAQuantizer(Quantizer): |
| def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.add.Tensor |
| ): |
| input_act0 = node.args[0] |
| assert isinstance(input_act0, Node) |
| input_act1 = node.args[1] |
| assert isinstance(input_act1, Node) |
| |
| act_qspec = QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=Int4Observer, |
| ) |
| node.meta["quantization_annotation"] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act0: act_qspec, |
| input_act1: act_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| |
| def validate(self, model: torch.fx.GraphModule) -> None: |
| pass |
| |
| |
| class M(torch.nn.Module): |
| def forward(self, x1, x2): |
| return x1 + x2 |
| |
| example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),) |
| node_occurrence = { |
| # two for the inputs of add, one for the output of add |
| torch.ops.test_int4.quantize_per_tensor_int4: 3, |
| torch.ops.test_int4.dequantize_per_tensor_int4: 3, |
| } |
| node_list = [ |
| torch.ops.test_int4.dequantize_per_tensor_int4, |
| torch.ops.test_int4.dequantize_per_tensor_int4, |
| torch.ops.aten.add.Tensor, |
| torch.ops.test_int4.quantize_per_tensor_int4, |
| ] |
| self._test_quantizer( |
| M().eval(), |
| example_inputs, |
| BackendAQuantizer(), |
| node_occurrence, |
| node_list, |
| ) |
| |
| def test_speed(self): |
| import time |
| |
| def dynamic_quantize_pt2e(model, example_inputs): |
| torch._dynamo.reset() |
| model = capture_pre_autograd_graph(model, example_inputs) |
| # Per-channel quantization for weights |
| # Dynamic quantization for activations |
| # For details see: https://fburl.com/code/30zds51q |
| embedding_quantizer = EmbeddingQuantizer() |
| dynamic_quantizer = XNNPACKQuantizer() |
| operator_config_dynamic = get_symmetric_quantization_config( |
| is_per_channel=True, is_dynamic=True |
| ) |
| dynamic_quantizer.set_global(operator_config_dynamic) |
| composed_quantizer = ComposableQuantizer([embedding_quantizer, dynamic_quantizer]) |
| prev = time.time() |
| model = prepare_qat_pt2e(model, composed_quantizer) |
| cur = time.time() |
| # print("prepare time:", cur - prev) |
| # Without calibration, scale/zero_point will keep their initialized value of 1.0. |
| # Per-channel quantization needs scale/zero_point of the proper shape and value to work, |
| # so we need to run calibration before converting to a quantized model. |
| model(*example_inputs) |
| prev = time.time() |
| model = convert_pt2e(model) |
| cur = time.time() |
| # uncomment to see the time |
| # print("convert time:", cur - prev) |
| return model |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| m = M().eval() |
| example_inputs = (torch.randn(5, 5),) |
| _ = dynamic_quantize_pt2e(m, example_inputs) |