import torch
from torch.fx.graph import (
Node,
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
)
from ..quantization_mappings import (
get_static_quant_module_class,
get_dynamic_quant_module_class,
get_quantized_operator,
)
from ..utils import (
get_swapped_custom_module_class,
activation_is_statically_quantized,
activation_is_int8_quantized,
weight_is_statically_quantized,
get_qconfig_dtypes,
)
from .pattern_utils import (
register_quant_pattern,
mark_input_output_not_observed,
)
from .utils import (
_parent_name,
all_node_args_have_no_tensors,
quantize_node,
get_per_tensor_qparams,
get_linear_prepack_op_for_dtype,
create_qparam_nodes,
get_qconv_prepack_op,
get_qconv_op,
)
from .quantization_types import QuantizerCls
from abc import ABC, abstractmethod
import operator
import warnings
from typing import Any, Callable, Dict, Union, Optional, Tuple, List
# -------------------------
# Pattern Registrations
# -------------------------
# 1. Post Training Static Quantization and Quantization Aware Training Patterns
# Base Pattern Handler
class QuantizeHandler(ABC):
""" Base handler class for the quantizer patterns
"""
def __init__(self, quantizer: QuantizerCls, node: Node):
""" Records pattern information in __init__, which will be used
in convert
"""
        # These record whether all of the inputs are Nodes (i.e. tensors), since
        # some ops (e.g. add/mul) are quantized differently depending on whether
        # all of their inputs are tensors or some are scalars.
self.num_tensor_args = len(node.args)
self.all_node_args_are_tensors = True
@abstractmethod
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
""" Convert the given node to a quantized node and insert
it to the quantized graph
"""
return NotImplemented
# Binary op configs
# Supported combinations are:
#  quant_type | activation (compute_type) | weight
#   static    |  quint8                   |  qint8
#   static    |  float16                  |  float16
# Each entry is a tuple of (activation_dtype, weight_dtype, compute_dtype);
# these are the supported dtype combinations for common binary ops like add/mul.
binary_op_all_dtypes = [
(torch.quint8, torch.qint8, None),
(torch.float16, torch.float16, None),
]
binary_op_float16_dtypes = [
(torch.float16, torch.float16, None)
]
binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
operator.add: binary_op_all_dtypes,
torch.add: binary_op_all_dtypes,
operator.mul: binary_op_all_dtypes,
torch.mul: binary_op_all_dtypes,
torch.bmm: binary_op_float16_dtypes,
torch.sub: binary_op_float16_dtypes,
operator.sub: binary_op_float16_dtypes,
torch.div: binary_op_float16_dtypes,
operator.truediv: binary_op_float16_dtypes,
torch.sum: binary_op_float16_dtypes
}
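# Illustrative sketch of where these tuples come from (assuming the default
# "fbgemm" qconfig; names below are placeholders):
#   qconfig = torch.quantization.get_default_qconfig("fbgemm")
#   get_qconfig_dtypes(qconfig)  # -> (torch.quint8, torch.qint8, None)
# which matches the first entry of binary_op_all_dtypes above.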
@register_quant_pattern(operator.add)
@register_quant_pattern(operator.sub)
@register_quant_pattern(operator.mul)
@register_quant_pattern(operator.truediv)
@register_quant_pattern(torch.add)
@register_quant_pattern(torch.sub)
@register_quant_pattern(torch.mul)
@register_quant_pattern(torch.div)
@register_quant_pattern(torch.sum)
@register_quant_pattern(torch.bmm)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@register_quant_pattern((torch.nn.ReLU, torch.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, torch.add))
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
class BinaryOpQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0] # type: ignore
self.binary_op_node = node
self.binary_op = node.target
# determine how many of the first two args are Tensors (versus scalars)
# this distinguishes things like "x + y" from "x + 2" or "2 + x"
self.num_tensor_args = 0
cache_for_no_tensor_check: Dict[Node, bool] = dict()
for arg_idx in range(len(self.binary_op_node.args)):
arg = self.binary_op_node.args[arg_idx]
if isinstance(arg, Node) and (not all_node_args_have_no_tensors(arg, quantizer.modules, cache_for_no_tensor_check)):
self.num_tensor_args += 1
self.all_node_args_are_tensors = \
(self.num_tensor_args == len(self.binary_op_node.args))
qbin_op_mapping: Dict[Union[Callable, str], Callable] = {
operator.add: torch.ops.quantized.add,
torch.add: torch.ops.quantized.add,
operator.mul: torch.ops.quantized.mul,
torch.mul: torch.ops.quantized.mul,
}
qbin_relu_op_mapping: Dict[Union[Callable, str], Callable] = {
operator.add: torch.ops.quantized.add_relu,
torch.add: torch.ops.quantized.add_relu,
operator.mul: torch.ops.quantized.mul_relu,
torch.mul: torch.ops.quantized.mul_relu,
}
# corresponding quantized op
self.quantized_binary_op: Optional[Callable] = None
if self.binary_op in qbin_op_mapping:
self.quantized_binary_op = qbin_relu_op_mapping[self.binary_op] \
if self.relu_node is not None \
else qbin_op_mapping[self.binary_op] # type: ignore
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
if dtypes not in binary_op_supported_dtypes[self.binary_op]:
            warnings.warn(
                "dtype combination: {} is not supported by {}. "
                "Supported dtype combinations are: {}".format(
                    dtypes, self.binary_op, binary_op_supported_dtypes[self.binary_op]))
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
return quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
if dtypes in [(torch.quint8, torch.qint8, None)]:
assert self.quantized_binary_op is not None
if self.num_tensor_args == 1:
# add/mul scalar
first_arg = self.binary_op_node.args[0]
cache_for_no_tensor_check: Dict[Node, bool] = dict()
if isinstance(first_arg, Node) and (
not all_node_args_have_no_tensors(
first_arg, quantizer.modules, cache_for_no_tensor_check)):
quantized_index = 0
else:
quantized_index = 1
return quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op,
load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
else:
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
kwargs = {**self.binary_op_node.kwargs}
add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
op = quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op, add_args, kwargs)
return op
else:
assert dtypes == (torch.float16, torch.float16, None)
# TODO (refactor) this is duplicated, maybe have a helper function
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
return quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
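# Illustrative example of the int8 lowering performed above (qx, qy, output_scale
# and output_zero_point are placeholder names): a float `torch.add(x, y)` followed
# by a ReLU is rewritten as roughly
#   torch.ops.quantized.add_relu(qx, qy, output_scale, output_zero_point)
# with the output qparams taken from the output activation observer.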
@register_quant_pattern(torch.cat)
class CatQuantizeHandler(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if not self.all_node_args_are_tensors:
return NotImplemented
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
kwargs = {**load_arg(quantized=False)(node.kwargs), 'scale': scale_arg, 'zero_point': zero_point_arg}
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
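# Illustrative example of the cat lowering above (qx, qy are placeholder names
# for int8-quantized inputs): `torch.cat([x, y], dim=1)` becomes roughly
#   torch.ops.quantized.cat([qx, qy], dim=1, scale=output_scale, zero_point=output_zero_point)
# with scale/zero_point taken from the output activation observer.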
# handle conv, maybe followed by relu
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv1d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.functional.conv3d)
# TODO: add qat.Conv1d
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.qat.Conv3d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU3d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
class ConvReluQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0] # type: ignore
self.conv_node = node
if node.op == "call_module":
self.conv = quantizer.modules[self.conv_node.target]
elif node.op == "call_function":
self.conv = node.target
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # Supported combinations are:
        #  quant_type | activation (compute_type) | weight
        #   static    |  quint8                   |  qint8
        # tuple: (activation_dtype, weight_dtype, compute_dtype)
supported_dtypes = [
(torch.quint8, torch.qint8, None),
]
# TODO: is_reference option for conv module
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not supported by Conv. "
                "Supported dtype combinations are: {}".format(dtypes, supported_dtypes))
if self.relu_node:
conv_out = quantizer.quantized_graph.node_copy(self.conv_node, load_arg(quantized=False))
relu_args = [conv_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
return quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
activation_int8_quantized = activation_is_int8_quantized(qconfig)
if self.conv_node.op == 'call_module':
# note that relu should already be fused into conv module in the fusion step
assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
# 1. attach activation post process to module
cur_idx = quantizer.activation_post_process_indexes[node.name]
self.conv.activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
# 2. select quantized class
qconv_cls = get_static_quant_module_class(
type(self.conv), additional_static_quant_mapping)
quantized = qconv_cls.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.conv_node.target,
(load_arg(quantized=True)(self.conv_node.args[0]),),
{})
else: # call_function
assert self.conv_node.op == "call_function"
if is_reference:
args = load_arg(quantized=[0, 1])(self.conv_node.args)
args = load_arg(quantized=False)(self.conv_node.args)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
op_out = quantizer.quantized_graph.create_node(
"call_function", self.conv, args, kwargs)
if self.relu_node:
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
op_out = quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
if activation_int8_quantized:
root_module = quantizer.modules['']
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
act_post_process_node = self.relu_node if self.relu_node else self.conv_node
cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
quantizer.activation_post_process_indexes[act_post_process_name] += 1
return quantize_node(
quantizer, op_out, activation_post_process,
act_post_process_node, is_input=False)
else:
# output for dynamically quantized conv op is not quantized
return op_out
else:
                assert len(self.conv_node.args) >= 7, \
                    "only conv2d calls with all arguments specified are supported " \
                    "right now when is_reference=False"
args = load_arg(quantized=[0, 1])(self.conv_node.args)
# pack weight
weight = load_arg(quantized=True)(self.conv_node.args[1])
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
prepack_args = tuple([weight] + list(other_args))
prepack_op = get_qconv_prepack_op(self.conv)
packed_weight = quantizer.quantized_graph.create_node(
"call_function", prepack_op, prepack_args, {})
assert activation_int8_quantized, \
"currently only static quantization is supported for conv"
# construct conv input
if activation_int8_quantized:
qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
quantizer.activation_post_process_indexes[act_post_process_name] += 1
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
op = quantizer.quantized_graph.create_node(
'call_function', qconv_op, qconv_args, kwargs)
# Store the name of the fused op to get the path of node after fusion as well.
                    # TODO: may need to change the key to Node, or regenerate the map in each transformation,
# since we might not be able to rely on the name
quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.conv_node.name]
return op
else:
                    # conv2d dynamic branch (not supported)
raise Exception("Only static quant is supported for conv")
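# Illustrative sketch of the non-reference functional conv lowering above
# (assuming F.conv2d with an int8 qconfig and all arguments passed positionally;
# qx, qweight, etc. are placeholder names):
#   packed = torch.ops.quantized.conv2d_prepack(qweight, bias, stride, padding, dilation, groups)
#   out = torch.ops.quantized.conv2d(qx, packed, output_scale, output_zero_point)
# (torch.ops.quantized.conv2d_relu is used instead when a ReLU is fused in).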
# handle linear, maybe followed by relu
@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLUQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0] # type: ignore
self.linear_node = node
if node.op == 'call_module':
self.linear = quantizer.modules[self.linear_node.target]
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # Supported combinations are:
        #  quant_type  | activation (compute_type) | weight
        #   static     |  quint8                   |  qint8
        #   dynamic    |  float32 (quint8)         |  qint8
        #   weight_only|  float32                  |  float16
        # tuple: (activation_dtype, weight_dtype, compute_dtype)
supported_dtypes = [
(torch.quint8, torch.qint8, None),
(torch.float32, torch.qint8, torch.quint8),
(torch.float32, torch.float16, None),
# static float16 quantization
(torch.float16, torch.float16, None),
]
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not supported by Linear. "
                "Supported dtype combinations are: {}".format(dtypes, supported_dtypes))
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.linear_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
return quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
activation_int8_quantized = activation_is_int8_quantized(qconfig)
weight_dtype = dtypes[1]
# TODO: reference_model option for linear module
if self.linear_node.op == 'call_module':
            # note that relu should already be fused into the linear module in the fusion step
assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
# 1. attach output activation post process to linear module
if node.name in quantizer.activation_post_process_map:
# this is the static quantization case
cur_idx = quantizer.activation_post_process_indexes[node.name]
output_activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
else:
output_activation_post_process = None
if output_activation_post_process:
self.linear.activation_post_process = output_activation_post_process
# 2. select corresponding quantized linear class for the float linear class
if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
qlinear = nnq.Linear if activation_int8_quantized else nnqd.Linear
elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
assert activation_int8_quantized, \
'Only int8 static quantization is supported for LinearReLU'
qlinear = torch.nn.intrinsic.quantized.LinearReLU
else:
raise Exception("unhandled linear type:", type(self.linear))
quantized = qlinear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
# activation needs to be quantized for static quantization
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target,
(load_arg(quantized=activation_int8_quantized)(self.linear_node.args[0]),), {})
else: # call_function
assert self.linear_node.op == 'call_function'
if is_reference:
quantized_input_idxs = []
if activation_int8_quantized:
quantized_input_idxs.append(0)
if weight_is_statically_quantized(qconfig):
quantized_input_idxs.append(1)
args = load_arg(quantized=quantized_input_idxs)(self.linear_node.args)
args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
op_out = quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.linear, args, kwargs)
if self.relu_node:
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
op_out = quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
if activation_int8_quantized:
# quantize output for statically quantized linear op
root_module = quantizer.modules['']
act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
act_post_process_node = self.relu_node if self.relu_node else self.linear_node
cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
quantizer.activation_post_process_indexes[act_post_process_name] += 1
return quantize_node(
quantizer,
op_out,
activation_post_process,
act_post_process_node,
is_input=False)
else:
# output for dynamically quantized linear op is not quantized
return op_out
else: # non-reference option
# prepacking weights for static int8 quant and dynamic quant
if dtypes != (torch.float16, torch.float16, None):
# linear args
# (x, weight, bias, ...)
weight_quantized = weight_is_statically_quantized(qconfig)
linear_weight = load_arg(quantized=weight_quantized)(self.linear_node.args[1])
# get other arguments
kwargs = {**load_arg(quantized=False)(self.linear_node.kwargs)}
# pack weight
bias = None
# all args after bias, including bias
other_args = load_arg(quantized=False)(self.linear_node.args[2:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
else:
assert 'bias' in kwargs, \
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = (linear_weight, bias)
prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
packed_weight = quantizer.quantized_graph.create_node(
'call_function', prepack_op, prepack_args, {})
# construct linear input
if activation_int8_quantized:
qlinear_op = torch.ops.quantized.linear_relu if self.relu_node else torch.ops.quantized.linear
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
quantizer.activation_post_process_indexes[act_post_process_name] += 1
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.linear_node.name, scale, zero_point)
qlinear_args = (linear_input, packed_weight, scale_node, zero_point_node)
op = quantizer.quantized_graph.create_node(
"call_function", qlinear_op, qlinear_args, kwargs)
# Store the name of the fused op to get the path of node after fusion as well.
                        # TODO: may need to change the key to Node, or regenerate the map in each transformation,
# since we might not be able to rely on the name
quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.linear_node.name]
return op
elif dtypes in [(torch.float32, torch.qint8, torch.quint8),
(torch.float32, torch.float16, None)]:
# choose linear dynamic or linear dynamic fp16 op based on weight dtype
qlinear_op = torch.ops.quantized.linear_dynamic \
if weight_dtype == torch.qint8 \
else torch.ops.quantized.linear_dynamic_fp16
linear_input = load_arg(quantized=False)(self.linear_node.args[0])
qlinear_args = (linear_input, packed_weight) # type: ignore
op_out = quantizer.quantized_graph.create_node(
"call_function", qlinear_op, qlinear_args, kwargs)
# Store the name of the dynamic op to get the path of node after replacement as well.
                        # TODO: may need to change the key to Node, or regenerate the map in each transformation,
# since we might not be able to rely on the name
quantizer.node_name_to_scope[op_out.name] = quantizer.node_name_to_scope[self.linear_node.name]
if self.relu_node:
op_out = quantizer.quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
return op_out
else:
assert dtypes == (torch.float16, torch.float16, None)
# TODO (refactor) this is duplicated, maybe have a helper function
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.linear_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
return quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
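# Illustrative sketch of the non-reference functional linear lowerings above
# (qx, qweight, etc. are placeholder names):
#   packed = torch.ops.quantized.linear_prepack(qweight, bias)
#   # static int8: activations are quantized
#   out = torch.ops.quantized.linear(qx, packed, output_scale, output_zero_point)
#   # dynamic: activations stay float; fp16 weights go through linear_prepack_fp16
#   out = torch.ops.quantized.linear_dynamic(x, packed)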
@register_quant_pattern(torch.nn.BatchNorm2d)
@register_quant_pattern(torch.nn.BatchNorm3d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
class BatchNormQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
assert node.op == 'call_module'
self.bn_node = node
self.bn = quantizer.modules[self.bn_node.target]
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
# 1. attach activation post process to module
cur_idx = quantizer.activation_post_process_indexes[node.name]
self.bn.activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping)
quantized = qbn_cls.from_float(self.bn)
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.bn_node.target,
load_arg(quantized=[0])(self.bn_node.args),
load_arg(quantized=False)(self.bn_node.kwargs))
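# Illustrative example of the module swap above (a sketch, assuming the default
# static quant mappings): an observed torch.nn.BatchNorm2d is replaced by
#   torch.nn.quantized.BatchNorm2d.from_float(observed_bn)
# which reads the scale/zero_point from the attached activation_post_process.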
@register_quant_pattern(torch.nn.Embedding)
@register_quant_pattern(torch.nn.EmbeddingBag)
@mark_input_output_not_observed()
class EmbeddingQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # Supported combinations are:
        #  quant_type   | activation | weight   | activation_compute_type
        #   weight_only | float32    | quint8   | None
        #   weight_only | float32    | quint4x2 | None
        # tuple: (activation_dtype, weight_dtype, compute_dtype)
supported_dtypes = [
(torch.float32, torch.quint8, None),
(torch.float32, torch.quint4x2, None),
]
assert node.op == 'call_module'
emb_node = node
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not supported by Embedding/EmbeddingBag. "
                "Supported dtype combinations are: {}".format(dtypes, supported_dtypes))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
emb = quantizer.modules[emb_node.target]
qemb = get_static_quant_module_class(type(emb))
quantized = qemb.from_float(emb)
parent_name, name = _parent_name(emb_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
emb_node.target,
load_arg(quantized=False)(emb_node.args),
load_arg(quantized=False)(emb_node.kwargs))
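# Illustrative example of the weight-only embedding swap above (a sketch,
# assuming a float_qparams weight-only qconfig): an observed torch.nn.Embedding
# is replaced by
#   torch.nn.quantized.Embedding.from_float(observed_emb)
# which stores quint8 (or, for EmbeddingBag, quint4x2) weights while the inputs
# and outputs stay float, hence @mark_input_output_not_observed() above.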
# TODO (maybe): merge with embedding quantize handler
@register_quant_pattern(torch.nn.GRUCell)
@register_quant_pattern(torch.nn.LSTMCell)
@register_quant_pattern(torch.nn.RNNCell)
@register_quant_pattern(torch.nn.LSTM)
@mark_input_output_not_observed()
class RNNDynamicQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # Supported combinations are:
        #  quant_type | activation | weight  | activation_compute_type
        #   dynamic   | float32    | qint8   | quint8
        #   dynamic   | float32    | float16 | None
        # tuple: (activation_dtype, weight_dtype, compute_dtype)
supported_dtypes = [
(torch.float32, torch.qint8, torch.quint8),
(torch.float32, torch.float16, None),
]
assert node.op == 'call_module'
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not supported by dynamic quantized RNN modules "
                "(LSTM/GRUCell/LSTMCell/RNNCell). "
                "Supported dtype combinations are: {}".format(dtypes, supported_dtypes))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
module = quantizer.modules[node.target]
qmodule_cls = get_dynamic_quant_module_class(type(module))
qmodule = qmodule_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, qmodule)
return quantizer.quantized_graph.create_node(
'call_module',
node.target,
load_arg(quantized=False)(node.args),
load_arg(quantized=False)(node.kwargs))
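# Illustrative example of the dynamic RNN swap above (a sketch, assuming the
# default dynamic quant mappings): an observed torch.nn.LSTM is replaced by
#   torch.nn.quantized.dynamic.LSTM.from_float(observed_lstm)
# i.e. weights are quantized ahead of time while activations remain float, so
# the inputs and outputs are not observed.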
ARGS_TO_SKIP = {
torch._ops.ops.quantized.hardswish: ['inplace'],
torch._ops.ops.quantized.instance_norm:
['running_mean', 'running_var', 'use_input_stats', 'momentum'],
}
@register_quant_pattern(torch.nn.ConvTranspose1d)
@register_quant_pattern(torch.nn.ConvTranspose2d)
@register_quant_pattern(torch.nn.ELU)
@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.Hardswish)
@register_quant_pattern(torch.nn.InstanceNorm1d)
@register_quant_pattern(torch.nn.InstanceNorm2d)
@register_quant_pattern(torch.nn.InstanceNorm3d)
@register_quant_pattern(torch.nn.LayerNorm)
@register_quant_pattern(torch.nn.SiLU)
@register_quant_pattern(torch.nn.functional.hardswish)
@register_quant_pattern(torch.nn.functional.instance_norm)
@register_quant_pattern(torch.nn.functional.layer_norm)
@register_quant_pattern(torch.nn.functional.leaky_relu)
@register_quant_pattern(torch.nn.functional.silu)
class DefaultNodeQuantizeHandler(QuantizeHandler):
    ''' Common quantized op: the first input and the first output will be quantized.
    '''
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
if node.op == "call_function" or node.op == "call_method":
self.op = node.target
elif node.op == "call_module":
self.op = type(quantizer.modules[node.target])
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if not self.all_node_args_are_tensors:
return NotImplemented
assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
'call_function are handled in DefaultNode'
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
all_dtypes = [
(torch.quint8, torch.qint8, None),
(torch.float16, torch.float16, None)
]
int8_dtypes = [
(torch.quint8, torch.qint8, None)
]
fp16_dtypes = [
(torch.float16, torch.float16, None)
]
supported_dtypes = {
torch.nn.ConvTranspose1d: int8_dtypes,
torch.nn.ConvTranspose2d: int8_dtypes,
torch.nn.ELU: int8_dtypes,
torch.nn.LeakyReLU: int8_dtypes,
torch.nn.Hardswish: int8_dtypes,
torch.nn.InstanceNorm1d: int8_dtypes,
torch.nn.InstanceNorm2d: int8_dtypes,
torch.nn.InstanceNorm3d: int8_dtypes,
torch.nn.LayerNorm: all_dtypes,
torch.nn.SiLU: fp16_dtypes,
torch.nn.functional.hardswish: int8_dtypes,
torch.nn.functional.instance_norm: int8_dtypes,
torch.nn.functional.layer_norm: all_dtypes,
torch.nn.functional.leaky_relu: int8_dtypes,
torch.nn.functional.silu: fp16_dtypes,
}
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
if dtypes not in supported_dtypes[self.op]:
            warnings.warn(
                "dtype combination: {} is not supported by {}. "
                "Supported dtype combinations are: {}".format(dtypes, self.op, supported_dtypes[self.op]))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
# TODO: make helper functions for (torch.quint8, torch.qint8, None)
if dtypes in [(torch.quint8, torch.qint8, None)]:
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
if node.op == 'call_module':
module = quantizer.modules[node.target]
module.activation_post_process = activation_post_process
quantized_module_cls = get_static_quant_module_class(
type(module), additional_static_quant_mapping)
quantized_module = quantized_module_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_module)
return quantizer.quantized_graph.create_node(
'call_module',
node.target,
load_arg(quantized=[0])(node.args),
load_arg(quantized=False)(node.kwargs))
else:
assert node.op == "call_function"
# call_function
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
                assert not isinstance(node.target, str), \
                    "Expecting node.target for call_function to be a function " \
                    "instead of a string"
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg, "output_zero_point": zero_point_arg}
if quantized_op in ARGS_TO_SKIP:
args_to_skip = ARGS_TO_SKIP[quantized_op]
for arg in args_to_skip:
if arg in kwargs:
kwargs.pop(arg)
return quantizer.quantized_graph.create_node(
"call_function", quantized_op, args, kwargs)
else:
assert dtypes == (torch.float16, torch.float16, None)
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
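# Illustrative sketch of the call_function lowering above (assuming an int8
# qconfig; qx, s and zp are placeholder names): a functional op such as
# torch.nn.functional.hardswish(x) is replaced by its quantized counterpart with
# the output qparams appended as keyword arguments, e.g.
#   torch.ops.quantized.hardswish(qx, output_scale=s, output_zero_point=zp)
# and any kwargs listed in ARGS_TO_SKIP (e.g. `inplace`) are dropped.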
# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
@register_quant_pattern(torch.nn.functional.elu)
class ELUQuantizeHandler(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale_arg, 'output_zero_point': zero_point_arg}
        kwargs.pop('inplace', None)
return quantizer.quantized_graph.create_node(
'call_function', quantized_op, args, kwargs)
@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant)
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
if dtypes == (torch.float16, torch.float16, None):
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
else:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
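# Illustrative note on the fixed qparams used above (a sketch of the intended
# ranges; the exact constants live in the fake_quantize module):
#   default_affine_fixed_qparams_fake_quant     # roughly scale=1/256, zero_point=0,   for sigmoid-like outputs in [0, 1]
#   default_symmetric_fixed_qparams_fake_quant  # roughly scale=2/256, zero_point=128, for tanh-like outputs in [-1, 1]
# so no calculate_qparams() call is needed at convert time for these ops.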
# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
@register_quant_pattern(torch.nn.AvgPool1d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.Identity)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@register_quant_pattern(torch.nn.functional.max_pool1d)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch.nn.functional.max_pool3d)
@register_quant_pattern(torch.nn.functional.relu)
@register_quant_pattern(torch.nn.functional.relu6)
@register_quant_pattern(torch.avg_pool1d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool3d)
@register_quant_pattern(torch.chunk)
@register_quant_pattern(torch.clamp)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.max)
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.min)
@register_quant_pattern(torch.repeat_interleave)
@register_quant_pattern(torch.sort)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern(operator.getitem)
@register_quant_pattern('chunk')
@register_quant_pattern('clamp')
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
@register_quant_pattern('relu')
@register_quant_pattern('relu_')
@register_quant_pattern('repeat')
@register_quant_pattern('repeat_interleave')
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
@register_quant_pattern('view')
class CopyNodeQuantizeHandler(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
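# Illustrative example of the pass-through behavior above (qx is a placeholder
# name for a quantized input): a shape op such as
#   qx.reshape(1, -1)
# is simply copied into the quantized graph; quantized=None means each argument
# is loaded as whatever version (quantized or float) is currently in the
# environment, so no quantize/dequantize nodes are inserted.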
# Default quantization handler, used for quantization of input and output
# of quantizable objects (e.g. modules and functionals)
class DefaultQuantizeHandler(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
assert self.all_node_args_are_tensors
root_module = quantizer.modules['']
cur_idx = quantizer.activation_post_process_indexes[node.name]
activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
return quantize_node(
quantizer,
node, activation_post_process, node, is_input=False)
class CustomModuleQuantizeHandler(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
""" Convert a float custom module to quantized custom module
"""
assert node.op == 'call_module'
assert convert_custom_config_dict is not None
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None)
assert custom_module_class_mapping is not None
qconfig = quantizer.qconfig_map[node.name]
observed_custom_module = quantizer.modules[node.target]
if activation_is_statically_quantized(qconfig):
assert node.name in quantizer.activation_post_process_map
cur_idx = quantizer.activation_post_process_indexes[node.name]
observed_custom_module.activation_post_process = \
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
quantizer.activation_post_process_indexes[node.name] += 1
quantized_custom_module_class = get_swapped_custom_module_class(
observed_custom_module, custom_module_class_mapping, qconfig)
quantized_custom_module = \
quantized_custom_module_class.from_observed(observed_custom_module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_custom_module)
        # The quantized input is hardcoded to None (i.e. take whatever is in the
        # environment). This can be extended if there is a need, e.g. by reading
        # the indexes of quantized inputs from some module attribute like
        # module._QUANTIZED_INPUT_INDEXES
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
class StandaloneModuleQuantizeHandler(QuantizeHandler):
""" Converts an observed standalone module to quantized standalone module
by calling convert_fx on the observed standalone module.
"""
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
assert node.op == 'call_module'
qconfig = quantizer.qconfig_map[node.name]
convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore
observed_standalone_module = quantizer.modules[node.target]
input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs.tolist()
quantized_standalone_module = convert(observed_standalone_module, is_reference=is_reference)
parent_name, name = _parent_name(node.target)
# update the modules dict
setattr(quantizer.modules[parent_name], name, quantized_standalone_module)
quantizer.modules[node.target] = quantized_standalone_module
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))