[quant][graphmode] Add quantized conv1d to graph mode (#38341)
Summary:
Adds graph mode (JIT) quantization support for aten::conv1d, following the existing conv2d/conv3d flow: conv1d is registered as a statically quantizable aten function, the weight/bias argument matchers learn its argument positions, quantized::conv1d_prepack / conv1d_unpack calls are inserted and the prepack is constant-folded, and a fusion pattern rewrites dequantize -> aten::conv1d -> quantize_per_tensor into a single quantized::conv1d call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/38341
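For reference, a minimal end-to-end sketch of the new path, mirroring the test added below. The quantize_script import path and the calibration helper here are illustrative assumptions (the test itself uses the internal _test_only_eval_fn, and the function has moved between modules across versions); treat this as a sketch, not the canonical API.

    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization import quantize_script  # import path assumed

    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv1d(3, 3, 3).float()

        def forward(self, x):
            return self.conv(x)

    def calibrate(model, data):
        # run a few batches through the observed model to collect stats
        for x, _ in data:
            model(x)

    data = [(torch.rand(1, 3, 10), torch.randint(0, 1, (1,), dtype=torch.long))
            for _ in range(2)]
    qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
    model = torch.jit.script(M()).eval()
    model = quantize_script(model, qconfig_dict, calibrate, [data], inplace=False)
    print(model.graph)  # should now contain quantized::conv1d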
Test Plan:
python test/test_quantization.py TestQuantizeScriptPTSQOps.test_quantized_conv1d
Imported from OSS
Differential Revision: D21554256
fbshipit-source-id: baf78c7788a38acd9362204990f0b22c21263dfb
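As an additional sanity check (not part of the test plan), the fused op can be invoked directly. The signatures below are taken from the rewrite patterns in this diff (conv1d_prepack takes weight, bias, stride, padding, dilation, groups; quantized::conv1d takes the quantized input, packed params, output scale, and zero point); shapes and quantization parameters are arbitrary illustrative values, and a quantized engine such as fbgemm must be available.

    import torch

    # qint8 weight: 3 output channels, 3 input channels, kernel size 3
    w = torch.quantize_per_tensor(torch.rand(3, 3, 3), 0.1, 0, torch.qint8)
    packed = torch.ops.quantized.conv1d_prepack(w, None, [1], [0], [1], 1)

    # quint8 activation of shape (N, C, L)
    x = torch.quantize_per_tensor(torch.rand(1, 3, 10), 0.1, 0, torch.quint8)
    y = torch.ops.quantized.conv1d(x, packed, 0.2, 0)
    print(y.shape)  # torch.Size([1, 3, 8])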
diff --git a/test/quantization/test_quantize_script.py b/test/quantization/test_quantize_script.py
index 7981b58..872eb02 100644
--- a/test/quantization/test_quantize_script.py
+++ b/test/quantization/test_quantize_script.py
@@ -25,6 +25,7 @@
# Testing utils
from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
+from torch.testing._internal.common_quantized import override_qengines
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import attrs_with_prefix
@@ -1100,7 +1101,12 @@
for individual ops end to end.
"""
def _test_op_impl(self, module, data, quantized_op):
- qconfig_dict = {'': get_default_qconfig('fbgemm')}
+ qengine = torch.backends.quantized.engine
+ if qengine == 'none':
+ qconfig = default_qconfig
+ else:
+ qconfig = get_default_qconfig(qengine)
+ qconfig_dict = {'': qconfig}
model = torch.jit.script(module).eval()
model = quantize_script(model, qconfig_dict, _test_only_eval_fn, [data], inplace=False)
FileCheck().check(quantized_op) \
@@ -1112,9 +1118,28 @@
return model
+ @override_qengines
+ def test_quantized_conv1d(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super(M, self).__init__()
+ self.conv = torch.nn.Conv1d(3, 3, 3).float()
+
+ def forward(self, x):
+ return self.conv(x)
+
+ data = [(torch.rand((1, 3, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
+ model = self._test_op_impl(M(), data, "quantized::conv1d")
+ # make sure there is only one quantize_per_tensor for the input
+ # and that conv1d_prepack is folded
+ FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
+ .run(model.graph)
+
+ FileCheck().check_not("quantized::conv1d_prepack") \
+ .run(model.graph)
+
@unittest.skipUnless(
'fbgemm' in torch.backends.quantized.supported_engines,
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.",
)
def test_quantized_conv2d(self):
diff --git a/torch/csrc/jit/passes/quantization/finalize.cpp b/torch/csrc/jit/passes/quantization/finalize.cpp
index 5963c3e..afa48cb 100644
--- a/torch/csrc/jit/passes/quantization/finalize.cpp
+++ b/torch/csrc/jit/passes/quantization/finalize.cpp
@@ -29,6 +29,20 @@
}
void insertPrepackUnpackForConv(std::shared_ptr<Graph>& graph) {
+ std::string conv1d_with_quant = R"(
+graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
+ %w_dequant = aten::dequantize(%w_quant)
+ %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
+ return (%r) )";
+
+ std::string conv1d_with_quant_prepack = R"(
+graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
+ %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv1d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
+ %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv1d_unpack(%packed_params)
+ %w_dequant = aten::dequantize(%w_quant_unpacked)
+ %r = aten::conv1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
+ return (%r) )";
+
std::string conv2d_with_quant = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
%w_dequant = aten::dequantize(%w_quant)
@@ -58,6 +72,7 @@
return (%r) )";
std::vector<std::vector<std::string>> patterns_and_replacements = {
+ {conv1d_with_quant, conv1d_with_quant_prepack},
{conv2d_with_quant, conv2d_with_quant_prepack},
{conv3d_with_quant, conv3d_with_quant_prepack}};
for (const auto& item : patterns_and_replacements) {
@@ -104,6 +119,7 @@
auto filter_fn = [](const Node* n) -> bool {
return (
(n->kind() == Symbol::fromQualString("quantized::linear_prepack")) ||
+ n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
n->kind() == Symbol::fromQualString("quantized::conv3d_prepack"));
};
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index 6c3a22f..bee8915 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -24,6 +24,7 @@
};
std::vector<std::string> _static_quantizable_aten_funcs = {
+ "conv1d",
"conv2d",
"conv3d",
"linear",
@@ -201,7 +202,11 @@
bool isWeight(Value* v) {
bool result = matchArgPattern(
v,
- AtenFuncArgs({{"conv2d", 1}, {"conv3d", 1}, {"linear", 1}, {"lstm", 2}}),
+ AtenFuncArgs({{"conv1d", 1},
+ {"conv2d", 1},
+ {"conv3d", 1},
+ {"linear", 1},
+ {"lstm", 2}}),
CallFuncArgs({{"linear", 2}}));
return result;
}
@@ -209,7 +214,8 @@
bool isBiasOfConvOrLinear(Value* v) {
bool result = matchArgPattern(
v,
- AtenFuncArgs({{"conv2d", 2}, {"conv3d", 2}, {"linear", 2}}),
+ AtenFuncArgs(
+ {{"conv1d", 2}, {"conv2d", 2}, {"conv3d", 2}, {"linear", 2}}),
CallFuncArgs({{"linear", 3}}));
return result;
}
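For context on the argument indices above: aten::conv1d's schema is conv1d(input, weight, bias, stride, padding, dilation, groups), so position 1 is the weight and position 2 is the bias. A quick way to confirm from Python, using an internal JIT helper that may change across versions:

    import torch

    # prints: aten::conv1d(Tensor input, Tensor weight, Tensor? bias, ...)
    for schema in torch._C._jit_get_schemas_for_operator('aten::conv1d'):
        print(schema)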
diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h
index c599bcf..1174b67 100644
--- a/torch/csrc/jit/passes/quantization/quantization_patterns.h
+++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h
@@ -22,6 +22,16 @@
};
std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() {
+ // aten::conv1d
+ std::string conv1d = R"(
+graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
+ %a_dequant = aten::dequantize(%a_quant)
+ %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
+ %w_dequant = aten::dequantize(%w_quant)
+ %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
+ %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
+ return (%r_quant) )";
+
// aten::conv2d
std::string conv2d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
@@ -54,6 +64,12 @@
%r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
return (%r_quant) )";
+ // quantized::conv1d
+ std::string quantized_conv1d = R"(
+graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
+ %r_quant = quantized::conv1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
+ return (%r_quant) )";
+
// quantized::conv2d
std::string quantized_conv2d = R"(
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
@@ -533,6 +549,7 @@
return (%r) )";
return {
+ {"quantized::conv1d", conv1d, quantized_conv1d},
{"quantized::conv2d", conv2d, quantized_conv2d},
{"quantized::conv2d_relu", conv2d_relu, quantized_conv2d_relu},
{"quantized::conv2d_relu", conv2d_inplace_relu, quantized_conv2d_relu},