[ONNX] Support dynamic scale & zero_point for fake_quantize_per_tensor_affine
Dynamic scale & zero_point require the opset 13 versions of
`ONNX::QuantizeLinear` and `ONNX::DequantizeLinear`.

Also improves the error message raised by the opset 10 symbolic function
when scale is not a constant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75697
Approved by: https://github.com/garymm
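
A minimal usage sketch (the model, tensor values, and output path are illustrative, not part of this PR): a model that takes `scale` and `zero_point` as runtime inputs can now be exported at `opset_version=13`, while opset 10 raises the improved error because its symbolic needs a constant scale.

```python
import torch

class FakeQuantModel(torch.nn.Module):
    def forward(self, x, scale, zero_point):
        # scale and zero_point arrive as graph inputs, not baked-in constants
        return torch.fake_quantize_per_tensor_affine(x, scale, zero_point, -128, 127)

x = torch.randn(6, 4, 3, 3)
scale = torch.tensor(1.0 / 127)
zero_point = torch.tensor(0)

# Exports QuantizeLinear/DequantizeLinear with scale & zero_point wired as inputs
torch.onnx.export(FakeQuantModel(), (x, scale, zero_point), "fq.onnx", opset_version=13)
```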
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3f61513..6a48e40 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -8599,6 +8599,19 @@
        self.run_test(FakeQuantizePerTensorModel(), (x))

    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_fake_quantize_per_tensor_dynamic_scale_zeropoint(self):
+        class FakeQuantizePerTensorModel(torch.nn.Module):
+            def forward(self, input, scale, zero_point):
+                quant_min = -128
+                quant_max = 127
+                return torch.fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)
+
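+        # scale and zero_point are model inputs here, exercising the dynamic-qparams path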
+        x = torch.randn(6, 4, 3, 3)
+        scale = torch.tensor(1.0 / 127)
+        zero_point = torch.tensor(0)
+        self.run_test(FakeQuantizePerTensorModel(), (x, scale, zero_point))
+
+    @skipIfUnsupportedMinOpsetVersion(13)
    def test_fake_quantize_per_channel(self):
        class FakeQuantizePerChannelModel(torch.nn.Module):
            def forward(self, input):
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
index f444a67..948b214 100644
--- a/torch/onnx/symbolic_opset10.py
+++ b/torch/onnx/symbolic_opset10.py
@@ -298,14 +298,20 @@
"please use opset 11 or higher.")
-@parse_args("v", "t", "i", "i", "i")
+@parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
- if quant_min not in [0, -128] or quant_max not in [127, 255]:
+ if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
raise RuntimeError(
- "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
+ "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+ "Got ({}, {})".format(quant_min, quant_max))
+    scale = sym_help._maybe_get_scalar(scale)
+    if scale is None:
+        sym_help._onnx_opset_unsupported_detailed(
+            "fake_quantize_per_tensor_affine", 10, 13, "Non-constant scale not supported")
    scale = scale.float().data  # Avoid exporter generating double type
-    zero_point_dtype = torch.int8 if quant_min == -128 else torch.uint8
-    zero_point = torch.tensor(zero_point, dtype=zero_point_dtype)  # ONNX requires zero_point to be tensor
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
    return g.op("DequantizeLinear", g.op("QuantizeLinear", inputs, scale, zero_point), scale, zero_point)
diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py
index 61f9793..ed0439e 100644
--- a/torch/onnx/symbolic_opset13.py
+++ b/torch/onnx/symbolic_opset13.py
@@ -131,20 +131,34 @@
@parse_args("v", "v", "v", "i", "i", "i")
def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
- if quant_min not in [0, -128] or quant_max not in [127, 255]:
+ if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
raise RuntimeError(
- "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
-
+ "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+ "Got ({}, {})".format(quant_min, quant_max))
# ONNX defines zero_point to be int8 or uint8
if quant_min == 0:
- zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx["Byte"])
+ zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
else:
- zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx["Char"])
+ zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
return g.op(
"DequantizeLinear",
g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
scale, zero_point, axis_i=axis)
+@parse_args("v", "v", "v", "i", "i")
+def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
+ if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
+ raise RuntimeError(
+ "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+ "Got ({}, {})".format(quant_min, quant_max))
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
+    if scale.type().scalarType() != "Float":
+        scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+    return g.op("DequantizeLinear", g.op("QuantizeLinear", inputs, scale, zero_point), scale, zero_point)
+

def _reduce_op_symbolic(onnx_op_name):
    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)