| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from dataclasses import dataclass |
| from typing import List |
| |
| import torch |
| from executorch.backends.example.example_operators.ops import module_to_annotator |
| from torch import fx |
| from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver |
| from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions |
| from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer |
| from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OperatorConfig |
| |
| |
| def get_uint8_tensor_spec(observer_or_fake_quant_ctr): |
| return QuantizationSpec( |
| dtype=torch.uint8, |
| quant_min=0, |
| quant_max=255, |
| qscheme=torch.per_tensor_affine, |
| is_dynamic=False, |
| observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, |
| ) |
| |
| |
@dataclass
class ExampleQuantConfig:
    """Per-tensor quantization specs for one annotated operator pattern."""

    # Spec applied to the pattern's activation inputs.
    input_quant_spec: QuantizationSpec
    # Spec applied to the pattern's outputs.
    output_quant_spec: QuantizationSpec
    # Spec applied to weight tensors.
    weight_quant_spec: QuantizationSpec
    # Spec for bias tensors. NOTE(review): in practice this may be None
    # (see default_static_config, which defers bias quantization to a later
    # pass) — consider widening the annotation to Optional[QuantizationSpec].
    bias_quant_spec: QuantizationSpec
| |
| |
# Default static (non-dynamic) quantization config: full-range uint8 specs for
# activations and weights; bias quantization is left unset.
default_static_config = ExampleQuantConfig(
    input_quant_spec=get_uint8_tensor_spec(HistogramObserver),
    output_quant_spec=get_uint8_tensor_spec(HistogramObserver),
    weight_quant_spec=get_uint8_tensor_spec(MinMaxObserver),
    # pyre-fixme[6]: expected `QuantizationSpec` but got `None`.
    # Bias quantization can be configured here, or done in a pass later on.
    bias_quant_spec=None,
)
| |
| |
def check_for_outside_users(partitions) -> bool:
    """Return True if any non-final partition leaks users outside the subgraph.

    Every partition except the last must have exactly one output node, and
    that node must have exactly one user; otherwise some consumer lives
    outside the delegatable subgraph and delegating the quantized partition
    to the backend would not be possible.
    """
    return any(
        len(part.output_nodes) != 1 or len(part.output_nodes[0].users) != 1
        for part in partitions[:-1]
    )
| |
| |
class ExampleQuantizer(Quantizer):
    """Example PT2E quantizer.

    Finds every supported fused module sequence in a model and annotates its
    partitions with the quantization specs from ``quant_config`` so that a
    downstream backend can delegate the quantized subgraphs.
    """

    def __init__(self, quantizer_supported_modules=None, quant_config=None):
        """
        Args:
            quantizer_supported_modules: optional iterable of module sequences
                to annotate; each entry must be a key of ``module_to_annotator``.
                Defaults to every sequence the annotator registry supports.
            quant_config: optional quantization config; defaults to
                ``default_static_config``.

        Raises:
            ValueError: if a requested module sequence has no registered
                annotator.
        """
        super().__init__()
        if quantizer_supported_modules is not None:
            # Validate up front: an unknown module has no annotator and would
            # otherwise fail later with an opaque KeyError in annotate().
            # Raise instead of `assert` so the check survives `python -O`.
            for module in quantizer_supported_modules:
                if module not in module_to_annotator:
                    raise ValueError(f"{module} is not supported by this quantizer")
            self.quantizer_supported_modules = quantizer_supported_modules
        else:
            self.quantizer_supported_modules = module_to_annotator.keys()
        self.quant_config = (
            quant_config if quant_config is not None else default_static_config
        )

    def annotate(self, model):
        """Annotate quantization specs on all matching partitions of ``model``.

        Returns the same model, mutated in place by the per-module annotators.
        """
        for supported_modules in self.quantizer_supported_modules:
            fused_partitions = find_sequential_partitions(
                model,
                list(supported_modules),
            )

            for partitions in fused_partitions:
                # Partitions with users outside the delegatable subgraph
                # cannot be delegated, so skip quantizing them.
                if check_for_outside_users(partitions):
                    continue

                annotator = module_to_annotator[supported_modules].annotate_handle
                annotator(partitions, self.quant_config)

        return model

    def validate(self, model: fx.GraphModule) -> None:
        """No-op validation hook required by the Quantizer interface."""
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        """Return the operator configs this quantizer advertises (none)."""
        return []