[Quant][ONEDNN] Fix weight reorder issue for grouped convolution (#91934)
**Summary**
This applies to the onednn quantization backend only.
QConv weights may be reordered to a different blocked format when the input shape changes at runtime. Due to a bug, the group information was not retained during such a reorder, which could leave the weights with a wrong shape afterwards. This PR fixes the bug by constructing the expected weight descriptor with the group count (`ideep::tensor::desc(params.pd.weights_desc(), groups())`) before reordering.
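A minimal sketch of the failure mode, mirroring the new test below (assumes a build where the onednn quantized engine is available):
```python
import torch

# Assumption: this build supports the onednn quantized engine.
torch.backends.quantized.engine = "onednn"

bs, ic, oc, groups = 1, 128, 512, 2
w = torch.randn(oc * groups, ic, 1, 1)
qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
w_packed = torch.ops.quantized.conv2d_prepack(
    qw, None, (1, 1), (0, 0), (1, 1), groups
)

# First run caches a conv primitive for a 28x28 input.
x = torch.randn(bs, ic * groups, 28, 28)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)

# A different spatial size can trigger a weight reorder; before this fix,
# the reorder dropped the group info and produced a wrongly shaped weight.
x = torch.randn(bs, ic * groups, 5, 4)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
```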
**Test plan**
`python test/test_quantization.py -k test_conv_reorder_issue_onednn`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91934
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp
index e86c927..3c6dcd9 100644
--- a/aten/src/ATen/native/quantized/cpu/qconv.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -1340,7 +1340,8 @@
dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
get_deconv_cache() = DeconvPrimitiveCache(cache_key, params, b);
- weights = weights.reorder_if_differ_in(params.pd.weights_desc());
+ auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
+ weights = weights.reorder_if_differ_in(expected_weight_desc);
});
if (get_deconv_cache().hit(cache_key)) {
DeconvParams& params = get_deconv_cache().get_params();
@@ -1372,7 +1373,8 @@
dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
get_conv_cache() = ConvPrimitiveCache(cache_key, params, b);
- weights = weights.reorder_if_differ_in(params.pd.weights_desc());
+ auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
+ weights = weights.reorder_if_differ_in(expected_weight_desc);
});
// If hit, use cached data. If miss, fall back to normal path.
if (get_conv_cache().hit(cache_key)) {
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 804e162..d38b26d 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -6201,22 +6201,23 @@
bs = 1
ic, oc = 128, 512
kh, kw = 1, 1
- ih, iw = 28, 28
bias = None
- strides, paddings, dilates, groups = (1, 1), (0, 0), (1, 1), 1
- w = torch.randn((oc, ic, kh, kw))
- qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
- x = torch.randn((bs, ic, ih, iw))
- qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
- w_packed = torch.ops.quantized.conv2d_prepack(
- qw, bias, strides, paddings, dilates, groups
- )
- torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
- ih, iw = 5, 4
- x = torch.randn((bs, ic, ih, iw))
- qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
- # The following should pass when input shape is changed
- torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
+ strides, paddings, dilates = (1, 1), (0, 0), (1, 1)
+ for groups in [1, 2]:
+ ih, iw = 28, 28
+ w = torch.randn((oc * groups, ic, kh, kw))
+ qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
+ x = torch.randn((bs, ic * groups, ih, iw))
+ qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
+ w_packed = torch.ops.quantized.conv2d_prepack(
+ qw, bias, strides, paddings, dilates, groups
+ )
+ torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
+ ih, iw = 5, 4
+ x = torch.randn((bs, ic * groups, ih, iw))
+ qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
+ # The following should pass when input shape is changed
+ torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
@skipIfNoONEDNN
def test_conv_transpose_reorder_issue_onednn(self):