groupnorm: graph mode static quant support (#39095)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39095
Hooks up GroupNorm to graph mode static quantization.
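
As context, a minimal sketch of the workflow this enables. The entry points (`torch.quantization.quantize_jit`, `get_default_qconfig`) and the calibration helper are assumptions about the graph mode API of the time, not part of this diff:
```
import torch
from torch.quantization import get_default_qconfig, quantize_jit

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.gn = torch.nn.GroupNorm(2, 4)

    def forward(self, x):
        return self.gn(x)

def calibrate(model, data):
    # Run representative inputs so observers record activation ranges.
    for x in data:
        model(x)

data = [torch.rand(1, 4, 10, 10) for _ in range(2)]
m = torch.jit.script(M()).eval()
m = quantize_jit(m, {'': get_default_qconfig('fbgemm')}, calibrate, [data])
# After this change the printed graph calls quantized::group_norm instead
# of dequantize -> aten::group_norm -> quantize_per_tensor.
print(m.graph)
```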
Test Plan:
```
python test/test_quantization.py TestQuantizeScriptPTSQOps.test_group_norm
```
Imported from OSS
Differential Revision: D21885257
fbshipit-source-id: 3415c4de76181b026d2f5bfebab130fea29e1d1e
diff --git a/test/quantization/test_quantize_script.py b/test/quantization/test_quantize_script.py
index d0439b9..11a2bc0 100644
--- a/test/quantization/test_quantize_script.py
+++ b/test/quantization/test_quantize_script.py
@@ -2141,6 +2141,13 @@
FileCheck().check_not("aten::layer_norm") \
.run(m.graph)
+ def test_group_norm(self):
+ data = [(torch.rand((1, 4, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
+ group_norm = torch.nn.GroupNorm(2, 4)
+ m = self._test_op_impl(group_norm, data, "quantized::group_norm")
+ FileCheck().check_not("aten::group_norm") \
+ .run(m.graph)
+
def test_quantize_general_shape_ops(self):
""" A test that checks dequantize will be swapped for
all supported general shape ops like aten::flatten
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index fd1859f..8f77e94 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -21,6 +21,7 @@
"batch_norm",
"hardswish",
"layer_norm",
+ "group_norm",
};
std::vector<std::string> _static_quantizable_aten_funcs = {
@@ -32,6 +33,7 @@
"matmul",
"hardswish",
"layer_norm",
+ "group_norm",
};
std::vector<std::string> _dynamic_quantizable_call_funcs = {
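
These two lists drive observer insertion: an op must appear in `_static_quantizable_call_funcs` / `_static_quantizable_aten_funcs` for the prepare pass to attach an observer to its output, which later supplies the `%output_scale` / `%output_zero_point` arguments used in the fusion patterns below. A hedged sketch of that prepare step (`prepare_jit` is an assumed export of the era):
```
import torch
from torch.quantization import get_default_qconfig, prepare_jit

m = torch.jit.script(torch.nn.GroupNorm(2, 4)).eval()
m = prepare_jit(m, {'': get_default_qconfig('fbgemm')})
# With "group_norm" in the statically quantizable lists, the
# insert-observers pass attaches an observer to the group_norm output,
# which later yields its output_scale / output_zero_point.
print(m.graph)
```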
diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h
index 74c3225..d14e2c6 100644
--- a/torch/csrc/jit/passes/quantization/quantization_patterns.h
+++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h
@@ -607,6 +607,19 @@
%r = quantized::layer_norm(%a_quant, %normalized_shape, %weight, %bias, %eps, %output_scale, %output_zero_point)
return (%r) )";
+ // quantized::group_norm
+ std::string group_norm = R"(
+graph(%a_quant, %num_groups, %weight, %bias, %eps, %cudnn_enabled, %output_scale, %output_zero_point, %scalar_type):
+ %a_dequant = aten::dequantize(%a_quant)
+ %r_gn = aten::group_norm(%a_dequant, %num_groups, %weight, %bias, %eps, %cudnn_enabled)
+ %r = aten::quantize_per_tensor(%r_gn, %output_scale, %output_zero_point, %scalar_type)
+ return (%r) )";
+
+ std::string quantized_group_norm = R"(
+graph(%a_quant, %num_groups, %weight, %bias, %eps, %cudnn_enabled, %output_scale, %output_zero_point, %scalar_type):
+ %r = quantized::group_norm(%a_quant, %num_groups, %weight, %bias, %eps, %output_scale, %output_zero_point)
+ return (%r) )";
+
  // ============= General Ops that inherit quantization parameters from input
// tensor =============
auto avg_pool1d = getInputTensorQParamOpFusionInfo(
@@ -806,6 +819,7 @@
{"quantized::mul", inplace_mul, quantized_mul},
{"quantized::hardswish", hardswish, quantized_hardswish},
{"quantized::layer_norm", layer_norm, quantized_layer_norm},
+ {"quantized::group_norm", group_norm, quantized_group_norm},
avg_pool1d,
avg_pool2d,
avg_pool3d,
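
The `(pattern, replacement)` pair registered above is applied by the subgraph rewriter at convert time; note that the replacement drops `%cudnn_enabled`, since `quantized::group_norm` takes no cuDNN flag. The fusion can be verified with FileCheck just as the new test does (a sketch; `m` here is a module quantized as in the Summary sketch above):
```
from torch.testing import FileCheck

# m: a scripted GroupNorm module after graph mode static quantization,
# e.g. as produced by the quantize_jit sketch in the Summary.
FileCheck().check("quantized::group_norm") \
           .check_not("aten::group_norm") \
           .run(m.graph)
```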