# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import functools
from typing import Any, Callable, Dict, Optional

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    propagate_annotation,
    QuantizationConfig,
)
from torch.fx import Node

__all__ = [
    "VulkanQuantizer",
    "get_weight_quantization_config",
]


@functools.lru_cache
def get_weight_quantization_config(
    is_per_channel: bool = True,
    weight_qmin: int = -128,
    weight_qmax: int = 127,
) -> QuantizationConfig:
    """Return a weight-only QuantizationConfig for the Vulkan backend.

    Weights are quantized to symmetric int8, either per output channel
    (ch_axis=0) or per tensor; activations and bias are left unquantized.
    """
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
    )
    # Small epsilon keeps scales nonzero when an observed weight range collapses.
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )
    quantization_config = QuantizationConfig(
        input_activation=None,  # activations stay in floating point
        output_activation=None,
        weight=weight_quantization_spec,
        bias=None,
        is_qat=False,
    )
    return quantization_config
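
# A minimal usage sketch (illustrative, not part of the module): with the
# defaults above, the returned config quantizes weights to per-channel
# symmetric int8 and leaves the activation and bias specs as None.
#
#     config = get_weight_quantization_config()
#     assert config.weight.qscheme == torch.per_channel_symmetric
#     assert config.input_activation is None and config.bias is None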


# Ops this quantizer currently knows how to annotate.
_SUPPORTED_OPS = [
    "linear",
]


class VulkanQuantizer(Quantizer):
    """Quantizer that annotates exported graphs for the ExecuTorch Vulkan backend."""

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None

    def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
        """Set the QuantizationConfig applied to all supported ops; returns self."""
        self.global_config = quantization_config
        return self

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Transform scalar operands into tensor attributes so they can be
        annotated, e.g. the scalar `2` in `x + 2` becomes a tensor attribute."""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Currently only static quantization is supported on Vulkan.
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        if quantization_config is None:
            return model
        # Delegate to the shared XNNPACK annotators for each supported op.
        for op in _SUPPORTED_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        self._annotate_all_static_patterns(
            model,
            self.global_config,
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
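

# A minimal end-to-end sketch (assumptions: `model`, `example_inputs`, and the
# calibration iterable `calibration_data` are placeholders you supply; the PT2E
# helpers below live in torch.ao.quantization.quantize_pt2e):
#
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     quantizer = VulkanQuantizer().set_global(get_weight_quantization_config())
#     exported = torch.export.export_for_training(model, example_inputs).module()
#     prepared = prepare_pt2e(exported, quantizer)  # inserts observers
#     for sample in calibration_data:
#         prepared(sample)  # calibration pass to populate observers
#     quantized = convert_pt2e(prepared)  # weight-only int8 graph for Vulkan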