[quant][graphmode][fx] Add support for general value ops (#43439)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43439
Porting op tests from test_quantize_jit.py
Test Plan:
TestQuantizeFxOps
Imported from OSS
Reviewed By: raghuramank100
Differential Revision: D23278585
fbshipit-source-id: ad29f39482cf4909068ce29555470ef430ea17f6
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index d86a0fd..6729d5e 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -535,7 +535,7 @@
for quant_type in self.static_quant_types:
m = self.checkGraphModeFxOp(M(), data, checks, quant_type)
-
+ @skipIfNoFBGEMM
def test_general_shape_ops(self):
""" A test that checks dequantize will be swapped for
all supported general shape ops like aten::flatten
@@ -652,3 +652,95 @@
}
for check in (order_check, count_check):
self.checkGraphModule(quantized, check)
+
+ @skipIfNoFBGEMM
+ def test_general_value_ops(self):
+ """ A test that checks correct patterns are produced for
+ all supported general value ops like aten::avg_pool2d \
+ without actually checking for execution of these ops
+ """
+ class M(torch.nn.Module):
+ def __init__(self):
+ super(M, self).__init__()
+ self.conv = torch.nn.Conv2d(3, 3, 3)
+ self.avg_pool1d = torch.nn.AvgPool1d(3)
+ self.avg_pool2d = torch.nn.AvgPool2d(3)
+ self.avg_pool3d = torch.nn.AvgPool3d(3)
+ self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
+ self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+ self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
+ self.leaky_relu = torch.nn.LeakyReLU()
+ self.hardsigmoid = torch.nn.Hardsigmoid()
+ self.sigmoid = torch.nn.Sigmoid()
+ self.tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.avg_pool1d(x)
+ x = self.avg_pool2d(x)
+ x = self.avg_pool3d(x)
+ x = self.adaptive_avg_pool1d(x)
+ x = self.adaptive_avg_pool2d(x)
+ x = self.adaptive_avg_pool3d(x)
+ x = F.avg_pool1d(x, 3)
+ x = F.avg_pool2d(x, 3)
+ x = F.avg_pool3d(x, 3)
+ x = F.adaptive_avg_pool1d(x, (1))
+ x = F.adaptive_avg_pool2d(x, (1, 1))
+ x = F.adaptive_avg_pool3d(x, (1, 1, 1))
+ x = torch.mean(x)
+ x = torch.mean(x, [2, 3], False)
+ x = x.mean()
+ x = x.mean([2, 3], True)
+ x = F.interpolate(x, 4, mode='nearest')
+ x = F.interpolate(x, 4, mode='linear')
+ x = self.leaky_relu(x)
+ x = F.leaky_relu(x)
+ x = F.leaky_relu(x, inplace=True)
+ x = x.leaky_relu()
+ x.leaky_relu_()
+ x = self.hardsigmoid(x)
+ x = F.hardsigmoid(x)
+ x = F.hardsigmoid(x, inplace=True)
+ x = x.hardsigmoid()
+ x.hardsigmoid_()
+ x = self.sigmoid(x)
+ x = torch.sigmoid(x)
+ # F.sigmoid is deprecated
+ x = x.sigmoid()
+ x.sigmoid_()
+ x = self.tanh(x)
+ # F.tanh is deprecated
+ x = torch.tanh(x)
+ x = x.tanh()
+ x.tanh_()
+ x = self.conv(x)
+ return x
+
+ # This model is not executable since we just put all ops
+ # in the same forward
+ m = M()
+ original = symbolic_trace(m)
+ # nothing to fuse so skipping the fuse step
+ quantizer = Quantizer()
+ qconfig_dict = {'': default_qconfig}
+ prepared = quantizer.prepare(original, qconfig_dict)
+ # not runnable
+ quantized = quantizer.convert(prepared)
+
+ # This checks that the dequantize from the output of first conv
+ # is being propagated to the end, so that we don't insert extra
+ # observers
+ order_check = [
+ ('call_function', torch.quantize_per_tensor),
+ ('call_module', nnq.Conv2d),
+ ('call_module', nnq.Conv2d),
+ ('call_method', 'dequantize'),
+ ]
+ # check exact counts of quantize and dequantize
+ count_check = {
+ ('call_function', torch.quantize_per_tensor) : 1,
+ ('call_method', 'dequantize') : 1
+ }
+ for check in (order_check, count_check):
+ self.checkGraphModule(quantized, check)
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 48f369a..447257e 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -451,23 +451,40 @@
'call_function', quantized_op, args, kwargs)
# these ops have quantized equivalents that do not need any extra information
+@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
+@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
+@register_quant_pattern(torch.nn.AvgPool1d)
@register_quant_pattern(torch.nn.AvgPool2d)
+@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Dropout)
+@register_quant_pattern(torch.nn.Hardsigmoid)
@register_quant_pattern(torch.nn.Hardtanh)
+@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
+@register_quant_pattern(torch.nn.Sigmoid)
+@register_quant_pattern(torch.nn.Tanh)
+@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
+@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.dropout)
+@register_quant_pattern(torch.nn.functional.hardsigmoid)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
+@register_quant_pattern(torch.nn.functional.interpolate)
+@register_quant_pattern(torch.nn.functional.leaky_relu)
+@register_quant_pattern(torch.nn.functional.max_pool1d)
@register_quant_pattern(torch.nn.functional.max_pool2d)
+@register_quant_pattern(torch.nn.functional.max_pool3d)
@register_quant_pattern(torch.nn.functional.relu)
@register_quant_pattern(torch.nn.functional.relu6)
+@register_quant_pattern(torch.avg_pool1d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
+@register_quant_pattern(torch._C._nn.avg_pool3d)
@register_quant_pattern(torch.chunk)
@register_quant_pattern(torch.clamp)
@register_quant_pattern(torch.flatten)
@@ -476,9 +493,11 @@
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.min)
@register_quant_pattern(torch.repeat_interleave)
+@register_quant_pattern(torch.sigmoid)
@register_quant_pattern(torch.sort)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
+@register_quant_pattern(torch.tanh)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@@ -487,6 +506,10 @@
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
+@register_quant_pattern('hardsigmoid')
+@register_quant_pattern('hardsigmoid_')
+@register_quant_pattern('leaky_relu')
+@register_quant_pattern('leaky_relu_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
@@ -497,12 +520,16 @@
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
+@register_quant_pattern('sigmoid')
+@register_quant_pattern('sigmoid_')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
+@register_quant_pattern('tanh')
+@register_quant_pattern('tanh_')
+@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
-@register_quant_pattern('transpose')
@register_quant_pattern('view')
class CopyNode(QuantizeHandler):
def convert(self, quantizer, node, load_arg, debug=False):
@@ -784,8 +811,11 @@
elif node.name in env:
return False
elif isinstance(node, list):
- if all(map(is_quantized, node)):
+ quantized = map(is_quantized, node)
+ if all(quantized):
return True
+ elif not any(quantized):
+ return False
else:
raise Exception("partially quantized inputs in list not handled yet")