from torch.fx import GraphModule

from .pt2e.prepare import prepare
from .pt2e._propagate_annotation import propagate_annotation
from .pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from .pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_decomposed_linear,
    _replace_dropout_for_eval,
)
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
# TODO: move quantizer to torch.ao.quantization
from torch.ao.quantization.pt2e.quantizer import (  # noqa: F401
    OperatorConfig,
    OperatorPatternType,
    QuantizationConfig,
    Quantizer,
    QuantizationSpecBase,
    QuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
    QuantizationAnnotation,
    XNNPACKQuantizer,
    EmbeddingQuantizer,
    ComposableQuantizer,
)
from torch.ao.quantization.pt2e.quantizer.utils import (  # noqa: F401
    get_bias_qspec,
    get_input_act_qspec,
    get_output_act_qspec,
    get_weight_qspec,
)
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (  # noqa: F401
    get_symmetric_quantization_config,
)
from torch.ao.quantization.backend_config import BackendConfig

from typing import Any, Tuple

__all__ = [
    "prepare_pt2e",
    "prepare_qat_pt2e",
    "convert_pt2e",
]

def _prepare_pt2e_deprecated(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
) -> GraphModule:
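    """Legacy prepare flow for exported models, driven by a ``QConfigMapping``
    and a ``BackendConfig`` (FX graph mode prepare) rather than a ``Quantizer``.
    Deprecated: new code should use ``prepare_pt2e``.
    """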
    node_name_to_scope = _get_node_name_to_scope(model)

    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = fx_prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config,
    )

    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_decomposed_linear(model)
    return model

def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
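    """Prepare an exported (ATen-level) ``GraphModule`` for post-training
    quantization: fuse conv + bn, let ``quantizer`` annotate and validate the
    graph, propagate the annotations, and insert observers.

    A minimal usage sketch, assuming the model has already been captured into
    an ATen-level ``GraphModule`` (e.g. via ``torch._dynamo.export`` with
    ``aten_graph=True``); ``set_global`` is the configuration method on
    ``XNNPACKQuantizer``, not something defined in this file::

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config()
        )
        prepared = prepare_pt2e(exported_model, quantizer)
        # run calibration data through ``prepared``, then call convert_pt2e
    """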
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check the quantizer's annotations to make sure conv and bn are both
    # configured to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    model = prepare(model, node_name_to_scope, is_qat=False)
    return model

def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
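    """Prepare an exported (ATen-level) ``GraphModule`` for quantization-aware
    training.

    Unlike ``prepare_pt2e``, conv + bn fusion uses the QAT-specific
    ``_fuse_conv_bn_qat`` and runs *after* annotation, so that ops introduced
    by the fused subgraph are not quantized unnecessarily; observers /
    fake quantizers are then inserted with ``is_qat=True``.
    """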
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
    quantizer.validate(model)
    propagate_annotation(model)
    # Perform fusion after annotation so that ops introduced by the fused
    # subgraph that do not need to be quantized are not annotated
    # TODO: only fuse if conv and bn are both configured to be quantized
    _fuse_conv_bn_qat(model)
    model = prepare(model, node_name_to_scope, is_qat=True)
    # TODO: remove this hack once we have better support for pattern matching
    # move around the weight observer for addmm (decomposed linear)
    _rearrange_weight_observer_for_decomposed_linear(model)
    return model

def convert_pt2e(
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
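    """Convert a calibrated or trained model produced by ``prepare_pt2e`` /
    ``prepare_qat_pt2e`` into a quantized ``GraphModule`` in the reference
    decomposed representation: dropout is switched to its eval behavior,
    observers are lowered to quantize/dequantize ops, QAT conv + bn subgraphs
    are folded back, and (optionally) the graph is rewritten into the integer
    reference representation.

    A minimal usage sketch, continuing from ``prepare_pt2e``;
    ``calibration_data`` is a placeholder for the user's calibration inputs::

        for inputs in calibration_data:
            prepared(*inputs)
        quantized = convert_pt2e(prepared)
    """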
    # TODO: Handle this in export itself, outside of quantization
    # See https://github.com/pytorch/pytorch/issues/103681.
    _replace_dropout_for_eval(model)
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    return model