blob: b899ac3fa11702f22c31db8b2707d3c4536c050d [file] [log] [blame]
from torch import nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
from .stubs import QuantStub, DeQuantStub
# Map for swapping float module to quantized ones
DEFAULT_MODULE_MAPPING = {
nn.Linear: nnq.Linear,
nn.ReLU: nnq.ReLU,
nn.ReLU6: nnq.ReLU6,
nn.Hardswish: nnq.Hardswish,
nn.Conv1d: nnq.Conv1d,
nn.Conv2d: nnq.Conv2d,
nn.Conv3d: nnq.Conv3d,
nn.BatchNorm2d: nnq.BatchNorm2d,
nn.BatchNorm3d: nnq.BatchNorm3d,
nn.LayerNorm: nnq.LayerNorm,
QuantStub: nnq.Quantize,
DeQuantStub: nnq.DeQuantize,
# Wrapper Modules:
nnq.FloatFunctional: nnq.QFunctional,
# Intrinsic modules:
nni.ConvReLU2d: nniq.ConvReLU2d,
nni.ConvReLU3d: nniq.ConvReLU3d,
nni.LinearReLU: nniq.LinearReLU,
nni.BNReLU2d: nniq.BNReLU2d,
nni.BNReLU3d: nniq.BNReLU3d,
nniqat.ConvReLU2d: nniq.ConvReLU2d,
nniqat.LinearReLU: nniq.LinearReLU,
nniqat.ConvBn2d: nnq.Conv2d,
nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
# QAT modules:
nnqat.Linear: nnq.Linear,
nnqat.Conv2d: nnq.Conv2d,
nnqat.Hardswish: nnq.Hardswish,
}
# Map for swapping float module to qat modules
DEFAULT_QAT_MODULE_MAPPING = {
nn.Linear: nnqat.Linear,
nn.Conv2d: nnqat.Conv2d,
nn.Hardswish: nnqat.Hardswish,
# Intrinsic modules:
nni.ConvBn2d: nniqat.ConvBn2d,
nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
nni.ConvReLU2d: nniqat.ConvReLU2d,
nni.LinearReLU: nniqat.LinearReLU
}
# Map for swapping dynamic modules
DEFAULT_DYNAMIC_MODULE_MAPPING = {
nn.Linear: nnqd.Linear,
nn.LSTM: nnqd.LSTM,
}
# Whitelist for propagating the qconfig
_EXCLUDE_QCONFIG_PROPAGATE_LIST = {
DeQuantStub,
}
_INCLUDE_QCONFIG_PROPAGATE_LIST = {
nn.Sequential,
}
DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST = (
(set(DEFAULT_MODULE_MAPPING.keys()) |
set(DEFAULT_QAT_MODULE_MAPPING.keys()) |
set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys()) |
_INCLUDE_QCONFIG_PROPAGATE_LIST) -
_EXCLUDE_QCONFIG_PROPAGATE_LIST
)
DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_WHITE_LIST = (
set(DEFAULT_MODULE_MAPPING.values())
| set(DEFAULT_QAT_MODULE_MAPPING.values())
| set(DEFAULT_DYNAMIC_MODULE_MAPPING.values())
| set(DEFAULT_MODULE_MAPPING.keys())
| set(DEFAULT_QAT_MODULE_MAPPING.keys())
| set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys())
| _INCLUDE_QCONFIG_PROPAGATE_LIST
) - _EXCLUDE_QCONFIG_PROPAGATE_LIST