[Quant][ONEDNN] Fix weight reorder issue for grouped convolution (#91934)

**Summary**
For the onednn quantization backend only.
A QConv weight may be reordered to a different blocked format when the input shape changes at runtime. The bug is that the group info is not retained in the target descriptor for such a reorder, which can leave the weight with the wrong shape afterwards. This PR fixes that bug.
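For illustration, a minimal repro of the failure mode, as a sketch assuming an onednn-enabled build (the shapes, scales, and zero points are illustrative and mirror the new test case below):

```python
import torch

# Select the onednn quantized engine (requires a build with ONEDNN support).
torch.backends.quantized.engine = "onednn"

bs, ic, oc, groups = 1, 128, 512, 2
w = torch.randn((oc * groups, ic, 1, 1))
qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
w_packed = torch.ops.quantized.conv2d_prepack(
    qw, None, (1, 1), (0, 0), (1, 1), groups
)

# First run primes the conv primitive cache for a 28x28 input.
x = torch.randn((bs, ic * groups, 28, 28))
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)

# A different spatial size triggers a weight reorder. Before this fix, the
# target descriptor dropped the group info, so the reordered weight could
# end up with the wrong shape; with the fix, this second call succeeds.
x = torch.randn((bs, ic * groups, 5, 4))
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
```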

**Test plan**
`python test/test_quantization.py -k test_conv_reorder_issue_onednn`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91934
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp
index e86c927..3c6dcd9 100644
--- a/aten/src/ATen/native/quantized/cpu/qconv.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -1340,7 +1340,8 @@
             dnnl::prop_kind::forward_inference,
             ideep::u8s8, ideep::engine::cpu_engine());
         get_deconv_cache() = DeconvPrimitiveCache(cache_key, params, b);
-        weights = weights.reorder_if_differ_in(params.pd.weights_desc());
+        auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
+        weights = weights.reorder_if_differ_in(expected_weight_desc);
     });
     if (get_deconv_cache().hit(cache_key)) {
       DeconvParams& params = get_deconv_cache().get_params();
@@ -1372,7 +1373,8 @@
             dnnl::prop_kind::forward_inference,
             ideep::u8s8, ideep::engine::cpu_engine());
         get_conv_cache() = ConvPrimitiveCache(cache_key, params, b);
-        weights = weights.reorder_if_differ_in(params.pd.weights_desc());
+        auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
+        weights = weights.reorder_if_differ_in(expected_weight_desc);
     });
     // If hit, use cached data. If miss, fall back to normal path.
     if (get_conv_cache().hit(cache_key)) {
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 804e162..d38b26d 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -6201,22 +6201,23 @@
             bs = 1
             ic, oc = 128, 512
             kh, kw = 1, 1
-            ih, iw = 28, 28
             bias = None
-            strides, paddings, dilates, groups = (1, 1), (0, 0), (1, 1), 1
-            w = torch.randn((oc, ic, kh, kw))
-            qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
-            x = torch.randn((bs, ic, ih, iw))
-            qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
-            w_packed = torch.ops.quantized.conv2d_prepack(
-                qw, bias, strides, paddings, dilates, groups
-            )
-            torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
-            ih, iw = 5, 4
-            x = torch.randn((bs, ic, ih, iw))
-            qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
-            # The following should pass when input shape is changed
-            torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
+            strides, paddings, dilates = (1, 1), (0, 0), (1, 1)
+            for groups in [1, 2]:
+                ih, iw = 28, 28
+                w = torch.randn((oc * groups, ic, kh, kw))
+                qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
+                x = torch.randn((bs, ic * groups, ih, iw))
+                qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
+                w_packed = torch.ops.quantized.conv2d_prepack(
+                    qw, bias, strides, paddings, dilates, groups
+                )
+                torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
+                ih, iw = 5, 4
+                x = torch.randn((bs, ic * groups, ih, iw))
+                qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
+                # The following should pass when input shape is changed
+                torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
 
     @skipIfNoONEDNN
     def test_conv_transpose_reorder_issue_onednn(self):