[quant][graphmode][fx] Add support for general value ops (#43439)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43439

This registers the general value ops (avg_pool1d/2d/3d, adaptive_avg_pool1d/2d/3d,
interpolate, leaky_relu, hardsigmoid, sigmoid, and tanh, in their module, functional,
and Tensor-method forms) as quant patterns handled by CopyNode in FX graph mode
quantization, so the dequantize following a preceding quantized op is propagated past
them instead of inserting extra observers. The op tests are ported from
test_quantize_jit.py.
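
For context, the prepare/convert flow the new test exercises looks roughly like
this. This is a minimal sketch based on the test added below, not part of the
change itself; the toy module, the two example ops, and the import paths are
assumptions made for illustration.

    import torch
    import torch.nn.functional as F
    from torch.fx import symbolic_trace                    # import paths approximate
    from torch.quantization import default_qconfig
    from torch.quantization.fx.quantize import Quantizer

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            x = self.conv(x)
            # general value ops have quantized equivalents, so the dequantize
            # inserted after the conv is propagated past them instead of
            # surrounding each op with extra observer/quantize nodes
            x = F.avg_pool2d(x, 3)
            x = torch.sigmoid(x)
            return x

    m = symbolic_trace(M())
    quantizer = Quantizer()
    prepared = quantizer.prepare(m, {'': default_qconfig})
    # ... run calibration data through `prepared` here ...
    quantized = quantizer.convert(prepared)
    # expected graph, roughly:
    #   quantize_per_tensor -> quantized conv -> avg_pool2d -> sigmoid -> dequantize
    # with exactly one quantize/dequantize pair

The real test chains many more ops (module, functional, and Tensor-method
variants) and only inspects the produced graph, since the chained shapes make
the model non-executable.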

Test Plan:
TestQuantizeFxOps
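
The class lives in test/quantization/test_quantize_fx.py; locally, something like
python test/test_quantization.py TestQuantizeFxOps.test_general_value_ops should
exercise the new test, assuming the FX quantization tests are wired into that
entry point like the other quantization tests.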

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D23278585

fbshipit-source-id: ad29f39482cf4909068ce29555470ef430ea17f6
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index d86a0fd..6729d5e 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -535,7 +535,7 @@
         for quant_type in self.static_quant_types:
             m = self.checkGraphModeFxOp(M(), data, checks, quant_type)
 
-
+    @skipIfNoFBGEMM
     def test_general_shape_ops(self):
         """ A test that checks dequantize will be swapped for
         all supported general shape ops like aten::flatten
@@ -652,3 +652,95 @@
         }
         for check in (order_check, count_check):
             self.checkGraphModule(quantized, check)
+
+    @skipIfNoFBGEMM
+    def test_general_value_ops(self):
+        """ A test that checks correct patterns are produced for
+        all supported general value ops like aten::avg_pool2d \
+        without actually checking for execution of these ops
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+                self.avg_pool1d = torch.nn.AvgPool1d(3)
+                self.avg_pool2d = torch.nn.AvgPool2d(3)
+                self.avg_pool3d = torch.nn.AvgPool3d(3)
+                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
+                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
+                self.leaky_relu = torch.nn.LeakyReLU()
+                self.hardsigmoid = torch.nn.Hardsigmoid()
+                self.sigmoid = torch.nn.Sigmoid()
+                self.tanh = torch.nn.Tanh()
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.avg_pool1d(x)
+                x = self.avg_pool2d(x)
+                x = self.avg_pool3d(x)
+                x = self.adaptive_avg_pool1d(x)
+                x = self.adaptive_avg_pool2d(x)
+                x = self.adaptive_avg_pool3d(x)
+                x = F.avg_pool1d(x, 3)
+                x = F.avg_pool2d(x, 3)
+                x = F.avg_pool3d(x, 3)
+                x = F.adaptive_avg_pool1d(x, (1))
+                x = F.adaptive_avg_pool2d(x, (1, 1))
+                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
+                x = torch.mean(x)
+                x = torch.mean(x, [2, 3], False)
+                x = x.mean()
+                x = x.mean([2, 3], True)
+                x = F.interpolate(x, 4, mode='nearest')
+                x = F.interpolate(x, 4, mode='linear')
+                x = self.leaky_relu(x)
+                x = F.leaky_relu(x)
+                x = F.leaky_relu(x, inplace=True)
+                x = x.leaky_relu()
+                x.leaky_relu_()
+                x = self.hardsigmoid(x)
+                x = F.hardsigmoid(x)
+                x = F.hardsigmoid(x, inplace=True)
+                x = x.hardsigmoid()
+                x.hardsigmoid_()
+                x = self.sigmoid(x)
+                x = torch.sigmoid(x)
+                # F.sigmoid is deprecated
+                x = x.sigmoid()
+                x.sigmoid_()
+                x = self.tanh(x)
+                # F.tanh is deprecated
+                x = torch.tanh(x)
+                x = x.tanh()
+                x.tanh_()
+                x = self.conv(x)
+                return x
+
+        # This model is not executable since we just chain all the ops
+        # in one forward without matching the intermediate shapes
+        m = M()
+        original = symbolic_trace(m)
+        # nothing to fuse so skipping the fuse step
+        quantizer = Quantizer()
+        qconfig_dict = {'': default_qconfig}
+        prepared = quantizer.prepare(original, qconfig_dict)
+        # not runnable
+        quantized = quantizer.convert(prepared)
+
+        # This checks that the dequantize from the output of first conv
+        # is being propagated to the end, so that we don't insert extra
+        # observers
+        order_check = [
+            ('call_function', torch.quantize_per_tensor),
+            ('call_module', nnq.Conv2d),
+            ('call_module', nnq.Conv2d),
+            ('call_method', 'dequantize'),
+        ]
+        # check exact counts of quantize and dequantize
+        count_check = {
+            ('call_function', torch.quantize_per_tensor) : 1,
+            ('call_method', 'dequantize') : 1
+        }
+        for check in (order_check, count_check):
+            self.checkGraphModule(quantized, check)
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 48f369a..447257e 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -451,23 +451,40 @@
             'call_function', quantized_op, args, kwargs)
 
 # these ops have quantized equivalents that do not need any extra information
+@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
 @register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
+@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
+@register_quant_pattern(torch.nn.AvgPool1d)
 @register_quant_pattern(torch.nn.AvgPool2d)
+@register_quant_pattern(torch.nn.AvgPool3d)
 @register_quant_pattern(torch.nn.Dropout)
+@register_quant_pattern(torch.nn.Hardsigmoid)
 @register_quant_pattern(torch.nn.Hardtanh)
+@register_quant_pattern(torch.nn.LeakyReLU)
 @register_quant_pattern(torch.nn.MaxPool1d)
 @register_quant_pattern(torch.nn.MaxPool2d)
 @register_quant_pattern(torch.nn.MaxPool3d)
 @register_quant_pattern(torch.nn.ReLU)
 @register_quant_pattern(torch.nn.ReLU6)
+@register_quant_pattern(torch.nn.Sigmoid)
+@register_quant_pattern(torch.nn.Tanh)
+@register_quant_pattern(torch.adaptive_avg_pool1d)
 @register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
+@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
 @register_quant_pattern(torch.nn.functional.dropout)
+@register_quant_pattern(torch.nn.functional.hardsigmoid)
 @register_quant_pattern(torch.nn.functional.hardtanh)
 @register_quant_pattern(torch.nn.functional.hardtanh_)
+@register_quant_pattern(torch.nn.functional.interpolate)
+@register_quant_pattern(torch.nn.functional.leaky_relu)
+@register_quant_pattern(torch.nn.functional.max_pool1d)
 @register_quant_pattern(torch.nn.functional.max_pool2d)
+@register_quant_pattern(torch.nn.functional.max_pool3d)
 @register_quant_pattern(torch.nn.functional.relu)
 @register_quant_pattern(torch.nn.functional.relu6)
+@register_quant_pattern(torch.avg_pool1d)
 @register_quant_pattern(torch._C._nn.avg_pool2d)
+@register_quant_pattern(torch._C._nn.avg_pool3d)
 @register_quant_pattern(torch.chunk)
 @register_quant_pattern(torch.clamp)
 @register_quant_pattern(torch.flatten)
@@ -476,9 +493,11 @@
 @register_quant_pattern(torch.mean)
 @register_quant_pattern(torch.min)
 @register_quant_pattern(torch.repeat_interleave)
+@register_quant_pattern(torch.sigmoid)
 @register_quant_pattern(torch.sort)
 @register_quant_pattern(torch.squeeze)
 @register_quant_pattern(torch.stack)
+@register_quant_pattern(torch.tanh)
 @register_quant_pattern(torch.unsqueeze)
 @register_quant_pattern(operator.getitem)
 @register_quant_pattern(operator.floordiv)
@@ -487,6 +506,10 @@
 @register_quant_pattern('contiguous')
 @register_quant_pattern('detach')
 @register_quant_pattern('detach_')
+@register_quant_pattern('hardsigmoid')
+@register_quant_pattern('hardsigmoid_')
+@register_quant_pattern('leaky_relu')
+@register_quant_pattern('leaky_relu_')
 @register_quant_pattern('mean')
 @register_quant_pattern('numel')
 @register_quant_pattern('permute')
@@ -497,12 +520,16 @@
 @register_quant_pattern('reshape')
 @register_quant_pattern('resize_')
 @register_quant_pattern('shape')
+@register_quant_pattern('sigmoid')
+@register_quant_pattern('sigmoid_')
 @register_quant_pattern('size')
 @register_quant_pattern('squeeze')
 @register_quant_pattern('squeeze_')
+@register_quant_pattern('tanh')
+@register_quant_pattern('tanh_')
+@register_quant_pattern('transpose')
 @register_quant_pattern('unsqueeze')
 @register_quant_pattern('unsqueeze_')
-@register_quant_pattern('transpose')
 @register_quant_pattern('view')
 class CopyNode(QuantizeHandler):
     def convert(self, quantizer, node, load_arg, debug=False):
@@ -784,8 +811,11 @@
                 elif node.name in env:
                     return False
             elif isinstance(node, list):
-                if all(map(is_quantized, node)):
+                quantized = list(map(is_quantized, node))
+                if all(quantized):
                     return True
+                elif not any(quantized):
+                    return False
                 else:
                     raise Exception("partially quantized inputs in list not handled yet")