# blob: 60d166ae4896df3c553b5b64c4aebe8703107574
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub
# Mapping from float (and QAT) module classes to the statically quantized module classes that replace them
STATIC_QUANT_MODULE_MAPPINGS = {
    # Plain float modules -> quantized modules
    nn.Linear: nnq.Linear,
    nn.ReLU: nnq.ReLU,
    nn.ReLU6: nnq.ReLU6,
    nn.Hardswish: nnq.Hardswish,
    nn.ELU: nnq.ELU,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.GroupNorm: nnq.GroupNorm,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    # Quant/dequant stubs -> quantize/dequantize modules
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    # Wrapper modules
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic float modules -> intrinsic quantized modules
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    # Intrinsic QAT modules -> quantized modules
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    # QAT modules -> quantized modules
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}
# Mapping from float module classes to the QAT module classes that replace them
QAT_MODULE_MAPPINGS = {
    # Plain float modules -> QAT modules
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic float modules -> intrinsic QAT modules
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU,
}
# Mapping from float module classes to the dynamically quantized module classes that replace them
DYNAMIC_QUANT_MODULE_MAPPINGS = {
    # Linear layer
    nn.Linear: nnqd.Linear,
    # Recurrent layer and recurrent cells
    nn.LSTM: nnqd.LSTM,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
    nn.GRUCell: nnqd.GRUCell,
}
# Module classes explicitly excluded from / included in the qconfig propagation list
# Classes subtracted from the sets built by get_qconfig_propagation_list()
# and get_compare_output_module_list().
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {DeQuantStub}
# Classes added to the sets built by get_qconfig_propagation_list() and
# get_compare_output_module_list() even though they are not mapping keys.
_INCLUDE_QCONFIG_PROPAGATE_LIST = {nn.Sequential}
# mapping from floating point function or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
    # Activation functions
    F.elu: torch._ops.ops.quantized.elu,
    F.hardswish: torch._ops.ops.quantized.hardswish,
    # Normalization functions
    F.instance_norm: torch._ops.ops.quantized.instance_norm,
    F.layer_norm: torch._ops.ops.quantized.layer_norm,
}
def register_static_quant_module_mapping(
        float_source_module_class, static_quant_target_module_class):
    ''' Add (or overwrite) an entry in the static quantization module mapping.

    Args:
        float_source_module_class: module class used as the lookup key.
        static_quant_target_module_class: module class stored for that key;
            it must expose a `from_float` attribute (expected to be a class
            method), which is checked with an assertion.

    The mapping is used in the convert step of post training static
    quantization to convert a float module to a statically quantized module.
    '''
    assert hasattr(static_quant_target_module_class, 'from_float'), \
        'from_float must be defined in quantized module class'
    # Mutates the module-level registry in place.
    STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = \
        static_quant_target_module_class
def get_static_quant_module_mappings():
    ''' Return the module mapping for post training static quantization.

    The module-level dict itself is returned (not a copy), so changes made
    by the caller are visible to every other user of the mapping.
    '''
    return STATIC_QUANT_MODULE_MAPPINGS
def get_static_quant_module_class(float_module_class):
    ''' Look up the statically quantized module class registered for
    `float_module_class`.

    Returns:
        The class stored in the static quantization mapping for the key.

    Raises:
        AssertionError: if no (non-None) class is registered for the key.
    '''
    quantized_cls = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class)
    assert quantized_cls is not None, (
        'Floating point module class {} does not have a corresponding'
        ' quantized module class'.format(float_module_class)
    )
    return quantized_cls
def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
    ''' Add (or overwrite) an entry in the QAT module mapping.

    Args:
        float_source_module_class: module class used as the lookup key.
        qat_target_module_class: module class stored for that key; it must
            expose a `from_float` attribute (expected to be a class method),
            which is checked with an assertion.

    The mapping is used in the prepare step of quantization aware training
    to swap a float module to a qat module.
    '''
    assert hasattr(qat_target_module_class, 'from_float'), \
        'from_float must be defined in qat module class'
    # Mutates the module-level registry in place.
    QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class
def get_qat_module_mappings():
    ''' Return the module mapping for quantization aware training.

    The module-level dict itself is returned (not a copy), so changes made
    by the caller are visible to every other user of the mapping.
    '''
    return QAT_MODULE_MAPPINGS
def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class):
    ''' Add (or overwrite) an entry in the dynamic quantization module mapping.

    Args:
        float_source_module_class: module class used as the lookup key.
        dynamic_quant_target_module_class: module class stored for that key;
            it must expose a `from_float` attribute (expected to be a class
            method), which is checked with an assertion.

    The mapping is used in the convert step of post training dynamic
    quantization to swap a float module to a dynamically quantized module.
    '''
    assert hasattr(dynamic_quant_target_module_class, 'from_float'), \
        'from_float must be defined in dynamically quantized module type'
    # Mutates the module-level registry in place.
    DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = \
        dynamic_quant_target_module_class
def get_dynamic_quant_module_mappings():
    ''' Return the module mapping for post training dynamic quantization.

    The module-level dict itself is returned (not a copy), so changes made
    by the caller are visible to every other user of the mapping.
    '''
    return DYNAMIC_QUANT_MODULE_MAPPINGS
def get_qconfig_propagation_list():
    ''' Build the set of module types that we'll attach the qconfig
    attribute to in prepare.

    The result is a fresh set on every call: the keys of the static, QAT
    and dynamic mappings plus the explicit include set, minus the explicit
    exclude set.
    '''
    mapped_source_classes = (
        set(STATIC_QUANT_MODULE_MAPPINGS.keys())
        | set(QAT_MODULE_MAPPINGS.keys())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
    )
    with_includes = mapped_source_classes | _INCLUDE_QCONFIG_PROPAGATE_LIST
    return with_includes - _EXCLUDE_QCONFIG_PROPAGATE_LIST
def get_compare_output_module_list():
    ''' Build the set of module class types whose output we will record
    in numeric suite.

    The result is a fresh set on every call: both the keys and the values
    of the static, QAT and dynamic mappings plus the explicit include set,
    minus the explicit exclude set.
    '''
    target_classes = (
        set(STATIC_QUANT_MODULE_MAPPINGS.values())
        | set(QAT_MODULE_MAPPINGS.values())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
    )
    source_classes = (
        set(STATIC_QUANT_MODULE_MAPPINGS.keys())
        | set(QAT_MODULE_MAPPINGS.keys())
        | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
    )
    recorded = target_classes | source_classes | _INCLUDE_QCONFIG_PROPAGATE_LIST
    return recorded - _EXCLUDE_QCONFIG_PROPAGATE_LIST
def register_quantized_operator_mapping(float_op, quantized_op):
    ''' Add (or overwrite) the mapping from `float_op` (torch or functional)
    to `quantized_op`.

    This is used in the convert step of fx based graph mode quantization
    to convert a float op to a quantized op.
    '''
    # Mutates the module-level registry in place; no validation is done.
    FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op
def get_quantized_operator(float_op):
    ''' Look up the quantized operator registered for `float_op`.

    Returns:
        The operator stored in the float-to-quantized operator mapping.

    Raises:
        AssertionError: if no (non-None) operator is registered for the key.
    '''
    mapped_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
    assert mapped_op is not None, \
        'Operator {} does not have corresponding quantized op'.format(float_op)
    return mapped_op