[Quant][fx] Fix get_default_qconfig_dict for fused modules

Summary: Calling `prepare_fx` with `get_default_qconfig_dict`
failed for models with fused modules, such as `ConvReLU2d`.
This commit fixes this by adding qconfig entries for ReLU
and BatchNorm as well.

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_qconfig_dict_with_fused_modules

Reviewers: jerryzh168

Subscribers: jerryzh168, vkuzo

Issue: https://github.com/pytorch/pytorch/issues/75825

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75838

Approved by: https://github.com/jerryzh168
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index a0eedae..de3cb11 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -1778,6 +1778,49 @@
         self.checkGraphModuleNodes(m, expected_node_list=node_list)
 
 
+    def test_qconfig_dict_with_fused_modules(self):
+        class LinearReLUModel(torch.nn.Module):
+            def __init__(self, relu):
+                super(LinearReLUModel, self).__init__()
+                self.linear = torch.nn.Linear(3, 3)
+                self.relu = relu
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.relu(x)
+                return x
+
+        class ConvReLUModel(torch.nn.Module):
+            def __init__(self, relu):
+                super(ConvReLUModel, self).__init__()
+                self.conv = torch.nn.Conv1d(3, 3, 3)
+                self.relu = relu
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.relu(x)
+                return x
+
+        class ConvBnReLUModel(torch.nn.Module):
+            def __init__(self, relu):
+                super(ConvBnReLUModel, self).__init__()
+                self.conv = torch.nn.Conv1d(3, 3, 3)
+                self.bn = torch.nn.BatchNorm1d(3)
+                self.relu = relu
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                x = self.relu(x)
+                return x
+
+        for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
+            for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
+                m = model(relu).eval()
+                qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
+                # should not crash as in https://github.com/pytorch/pytorch/issues/75825
+                prepare_fx(m, qconfig_dict)
+
     def test_qconfig_dict_validity(self):
         r"""
         Verifies that if a user passes an invalid key or makes a typo when
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index 94e9646..c093d71 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -352,7 +352,13 @@
                         (torch.nn.functional.conv_transpose1d, qconfig_transpose),
                         (torch.nn.functional.conv_transpose2d, qconfig_transpose),
                         (torch.nn.functional.conv_transpose3d, qconfig_transpose),
-                        (torch.nn.functional.linear, qconfig)]}
+                        (torch.nn.functional.linear, qconfig),
+                        (torch.nn.ReLU, qconfig),
+                        (torch.nn.functional.relu, qconfig),
+                        (torch.relu, qconfig),
+                        (torch.nn.BatchNorm1d, qconfig),
+                        (torch.nn.BatchNorm2d, qconfig),
+                        (torch.nn.BatchNorm3d, qconfig)]}
 
 def get_default_qconfig_dict(backend='fbgemm', version=0):
     qconfig = get_default_qconfig(backend, version)