[quant][graphmode] Add quantized conv1d to graph mode (#38341)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38341

Test Plan:
python test/test_quantization.py TestQuantizeScriptPTSQOps.test_quantized_conv1d

Imported from OSS

Differential Revision: D21554256

fbshipit-source-id: baf78c7788a38acd9362204990f0b22c21263dfb
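
For reference, a minimal sketch of the end-to-end flow the new test exercises (the quantize_script import path and the calibrate helper here are illustrative; the test itself uses _test_only_eval_fn from common_quantization):

import torch
from torch.quantization import get_default_qconfig
# quantize_script lived under this private module at the time of this PR;
# the exact import path may differ in other versions
from torch.quantization._quantize_script import quantize_script

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = torch.nn.Conv1d(3, 3, 3).float()

    def forward(self, x):
        return self.conv(x)

def calibrate(model, data):
    # run a few batches so the inserted observers record activation ranges
    for x, _ in data:
        model(x)

data = [(torch.rand(1, 3, 10), torch.randint(0, 1, (1,), dtype=torch.long))
        for _ in range(2)]
qconfig_dict = {'': get_default_qconfig('fbgemm')}
model = torch.jit.script(M()).eval()
model = quantize_script(model, qconfig_dict, calibrate, [data], inplace=False)
# after conversion the graph should call quantized::conv1d instead of aten::conv1d
print(model.graph)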
diff --git a/test/quantization/test_quantize_script.py b/test/quantization/test_quantize_script.py
index 7981b58..872eb02 100644
--- a/test/quantization/test_quantize_script.py
+++ b/test/quantization/test_quantize_script.py
@@ -25,6 +25,7 @@
 
 # Testing utils
 from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
+from torch.testing._internal.common_quantized import override_qengines
 
 from torch.testing import FileCheck
 from torch.testing._internal.jit_utils import attrs_with_prefix
@@ -1100,7 +1101,12 @@
     for individual ops end to end.
     """
     def _test_op_impl(self, module, data, quantized_op):
-        qconfig_dict = {'': get_default_qconfig('fbgemm')}
+        qengine = torch.backends.quantized.engine
+        if qengine == 'none':
+            qconfig = default_qconfig
+        else:
+            qconfig = get_default_qconfig(qengine)
+        qconfig_dict = {'': qconfig}
         model = torch.jit.script(module).eval()
         model = quantize_script(model, qconfig_dict, _test_only_eval_fn, [data], inplace=False)
         FileCheck().check(quantized_op) \
@@ -1112,9 +1118,28 @@
 
         return model
 
+    @override_qengines
+    def test_quantized_conv1d(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv = torch.nn.Conv1d(3, 3, 3).float()
+
+            def forward(self, x):
+                return self.conv(x)
+
+        data = [(torch.rand((1, 3, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
+        model = self._test_op_impl(M(), data, "quantized::conv1d")
+        # make sure there is only one quantize_per_tensor for the input
+        # and that conv1d_prepack has been folded
+        FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
+                   .run(model.graph)
+
+        FileCheck().check_not("quantized::conv1d_prepack") \
+                   .run(model.graph)
+
     @unittest.skipUnless(
         'fbgemm' in torch.backends.quantized.supported_engines,
-        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
         " with instruction set support avx2 or newer.",
     )
     def test_quantized_conv2d(self):
diff --git a/torch/csrc/jit/passes/quantization/finalize.cpp b/torch/csrc/jit/passes/quantization/finalize.cpp
index 5963c3e..afa48cb 100644
--- a/torch/csrc/jit/passes/quantization/finalize.cpp
+++ b/torch/csrc/jit/passes/quantization/finalize.cpp
@@ -29,6 +29,20 @@
 }
 
 void insertPrepackUnpackForConv(std::shared_ptr<Graph>& graph) {
+  std::string conv1d_with_quant = R"(
+graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
+        %w_dequant = aten::dequantize(%w_quant)
+        %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
+        return (%r) )";
+
+  std::string conv1d_with_quant_prepack = R"(
+graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
+        %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv1d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
+        %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv1d_unpack(%packed_params)
+        %w_dequant = aten::dequantize(%w_quant_unpacked)
+        %r = aten::conv1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
+        return (%r) )";
+
   std::string conv2d_with_quant = R"(
 graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
         %w_dequant = aten::dequantize(%w_quant)
@@ -58,6 +72,7 @@
         return (%r) )";
 
   std::vector<std::vector<std::string>> patterns_and_replacements = {
+      {conv1d_with_quant, conv1d_with_quant_prepack},
       {conv2d_with_quant, conv2d_with_quant_prepack},
       {conv3d_with_quant, conv3d_with_quant_prepack}};
   for (const auto& item : patterns_and_replacements) {
@@ -104,6 +119,7 @@
   auto filter_fn = [](const Node* n) -> bool {
     return (
         (n->kind() == Symbol::fromQualString("quantized::linear_prepack")) ||
+        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
         n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
         n->kind() == Symbol::fromQualString("quantized::conv3d_prepack"));
   };
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index 6c3a22f..bee8915 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -24,6 +24,7 @@
 };
 
 std::vector<std::string> _static_quantizable_aten_funcs = {
+    "conv1d",
     "conv2d",
     "conv3d",
     "linear",
@@ -201,7 +202,11 @@
 bool isWeight(Value* v) {
   bool result = matchArgPattern(
       v,
-      AtenFuncArgs({{"conv2d", 1}, {"conv3d", 1}, {"linear", 1}, {"lstm", 2}}),
+      AtenFuncArgs({{"conv1d", 1},
+                    {"conv2d", 1},
+                    {"conv3d", 1},
+                    {"linear", 1},
+                    {"lstm", 2}}),
       CallFuncArgs({{"linear", 2}}));
   return result;
 }
@@ -209,7 +214,8 @@
 bool isBiasOfConvOrLinear(Value* v) {
   bool result = matchArgPattern(
       v,
-      AtenFuncArgs({{"conv2d", 2}, {"conv3d", 2}, {"linear", 2}}),
+      AtenFuncArgs(
+          {{"conv1d", 2}, {"conv2d", 2}, {"conv3d", 2}, {"linear", 2}}),
       CallFuncArgs({{"linear", 3}}));
   return result;
 }
diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h
index c599bcf..1174b67 100644
--- a/torch/csrc/jit/passes/quantization/quantization_patterns.h
+++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h
@@ -22,6 +22,16 @@
 };
 
 std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() {
+  // aten::conv1d
+  std::string conv1d = R"(
+graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
+        %a_dequant = aten::dequantize(%a_quant)
+        %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
+        %w_dequant = aten::dequantize(%w_quant)
+        %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
+        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
+        return (%r_quant) )";
+
   // aten::conv2d
   std::string conv2d = R"(
 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
@@ -54,6 +64,12 @@
         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
         return (%r_quant) )";
 
+  // quantized::conv1d
+  std::string quantized_conv1d = R"(
+graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
+        %r_quant = quantized::conv1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
+        return (%r_quant) )";
+
   // quantized::conv2d
   std::string quantized_conv2d = R"(
 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
@@ -533,6 +549,7 @@
           return (%r) )";
 
   return {
+      {"quantized::conv1d", conv1d, quantized_conv1d},
       {"quantized::conv2d", conv2d, quantized_conv2d},
       {"quantized::conv2d_relu", conv2d_relu, quantized_conv2d_relu},
       {"quantized::conv2d_relu", conv2d_inplace_relu, quantized_conv2d_relu},