[Quant][fx] Fix get_default_qconfig_dict for fused modules
Summary: Calling `prepare_fx` with `get_default_qconfig_dict`
failed for models with fused modules, such as `ConvReLU2d`.
This commit fixes the failure by adding qconfig entries for ReLU
and BatchNorm alongside the existing Conv and Linear entries.
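
A minimal sketch of the call path this fixes, assuming the FX APIs as of this
commit (`prepare_fx` taking a `qconfig_dict` directly, imported from
`torch.ao.quantization.quantize_fx`); the `ConvBnReLU` module below is
illustrative and mirrors the new test added in this PR:

    import torch
    from torch.ao.quantization.quantize_fx import prepare_fx

    # Illustrative module with a fusable Conv -> BatchNorm -> ReLU pattern.
    class ConvBnReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    m = ConvBnReLU().eval()
    qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
    # Before this fix, this call failed (see issue 75825) because ReLU and
    # BatchNorm had no entries in the default qconfig_dict.
    prepared = prepare_fx(m, qconfig_dict)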
Test Plan:
python test/test_quantization.py TestQuantizeFx.test_qconfig_dict_with_fused_modules
Reviewers: jerryzh168
Subscribers: jerryzh168, vkuzo
Issue: https://github.com/pytorch/pytorch/issues/75825
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75838
Approved by: https://github.com/jerryzh168
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index a0eedae..de3cb11 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -1778,6 +1778,49 @@
self.checkGraphModuleNodes(m, expected_node_list=node_list)
+ def test_qconfig_dict_with_fused_modules(self):
+ class LinearReLUModel(torch.nn.Module):
+ def __init__(self, relu):
+ super(LinearReLUModel, self).__init__()
+ self.linear = torch.nn.Linear(3, 3)
+ self.relu = relu
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.relu(x)
+ return x
+
+ class ConvReLUModel(torch.nn.Module):
+ def __init__(self, relu):
+ super(ConvReLUModel, self).__init__()
+ self.conv = torch.nn.Conv1d(3, 3, 3)
+ self.relu = relu
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.relu(x)
+ return x
+
+ class ConvBnReLUModel(torch.nn.Module):
+ def __init__(self, relu):
+ super(ConvBnReLUModel, self).__init__()
+ self.conv = torch.nn.Conv1d(3, 3, 3)
+ self.bn = torch.nn.BatchNorm1d(3)
+ self.relu = relu
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+ for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
+ for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
+ m = model(relu).eval()
+ qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
+ # should not crash as in https://github.com/pytorch/pytorch/issues/75825
+ prepare_fx(m, qconfig_dict)
+
def test_qconfig_dict_validity(self):
r"""
Verifies that if a user passes an invalid key or makes a typo when
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index 94e9646..c093d71 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -352,7 +352,13 @@
(torch.nn.functional.conv_transpose1d, qconfig_transpose),
(torch.nn.functional.conv_transpose2d, qconfig_transpose),
(torch.nn.functional.conv_transpose3d, qconfig_transpose),
- (torch.nn.functional.linear, qconfig)]}
+ (torch.nn.functional.linear, qconfig),
+ (torch.nn.ReLU, qconfig),
+ (torch.nn.functional.relu, qconfig),
+ (torch.relu, qconfig),
+ (torch.nn.BatchNorm1d, qconfig),
+ (torch.nn.BatchNorm2d, qconfig),
+ (torch.nn.BatchNorm3d, qconfig)]}
def get_default_qconfig_dict(backend='fbgemm', version=0):
qconfig = get_default_qconfig(backend, version)
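
As a quick sanity check, the new entries can be confirmed from the returned
dict; this is a sketch assuming the default qconfig_dict exposes its per-type
entries under the "object_type" key as (type_or_callable, qconfig) pairs,
matching the dict being extended above:

    import torch

    qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
    # Collect every module class / callable that now has a default qconfig.
    covered = {obj for obj, _ in qconfig_dict["object_type"]}
    assert torch.nn.ReLU in covered and torch.nn.functional.relu in covered
    assert torch.nn.BatchNorm1d in covered and torch.nn.BatchNorm2d in covered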