[ONNX] Support dynamic scale & zero_point for fake_quantize_per_tensor_affine

Dynamic scale & zero_point require the opset 13 `ONNX::QuantizeLinear`
and `ONNX::DequantizeLinear` operators.
Also improved the error message raised by the opset 10 symbolic function
when scale is not a constant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75697
Approved by: https://github.com/garymm
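
For illustration, a minimal export sketch mirroring the new test (the module
name `FakeQuantModel` and output file name are hypothetical, not part of the
change):

```python
import torch

class FakeQuantModel(torch.nn.Module):  # hypothetical example module
    def forward(self, input, scale, zero_point):
        # scale and zero_point are runtime tensors, not Python constants,
        # so this exercises the dynamic path added by this PR.
        return torch.fake_quantize_per_tensor_affine(
            input, scale, zero_point, quant_min=-128, quant_max=127)

x = torch.randn(6, 4, 3, 3)
scale = torch.tensor(1.0 / 127)
zero_point = torch.tensor(0)

# opset_version >= 13 is required for the dynamic qparams; with opset 10
# the export now fails with a clear error unless scale is a constant.
torch.onnx.export(
    FakeQuantModel(), (x, scale, zero_point), "fake_quant.onnx",
    input_names=["input", "scale", "zero_point"], opset_version=13)
```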
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3f61513..6a48e40 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -8599,6 +8599,19 @@
         self.run_test(FakeQuantizePerTensorModel(), (x))
 
     @skipIfUnsupportedMinOpsetVersion(13)
+    def test_fake_quantize_per_tensor_dynamic_scale_zeropoint(self):
+        class FakeQuantizePerTensorModel(torch.nn.Module):
+            def forward(self, input, scale, zero_point):
+                quant_min = -128
+                quant_max = 127
+                return torch.fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)
+
+        x = torch.randn(6, 4, 3, 3)
+        scale = torch.tensor(1. / 127)
+        zero_point = torch.tensor(0)
+        self.run_test(FakeQuantizePerTensorModel(), (x, scale, zero_point))
+
+    @skipIfUnsupportedMinOpsetVersion(13)
     def test_fake_quantize_per_channel(self):
         class FakeQuantizePerChannelModel(torch.nn.Module):
             def forward(self, input):
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
index f444a67..948b214 100644
--- a/torch/onnx/symbolic_opset10.py
+++ b/torch/onnx/symbolic_opset10.py
@@ -298,14 +298,20 @@
                                           "please use opset 11 or higher.")
 
 
-@parse_args("v", "t", "i", "i", "i")
+@parse_args("v", "v", "v", "i", "i")
 def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
-    if quant_min not in [0, -128] or quant_max not in [127, 255]:
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
         raise RuntimeError(
-            "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
+            "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+            "Got ({}, {})".format(quant_min, quant_max))
+    scale = sym_help._maybe_get_scalar(scale)
+    if sym_help._is_value(scale):  # scale is not a compile-time constant
+        sym_help._onnx_opset_unsupported_detailed("fake_quantize_per_tensor_affine", 10, 13, "Non-constant scale not supported")
     scale = scale.float().data  # Avoid exporter generating double type
-    zero_point_dtype = torch.int8 if quant_min == -128 else torch.uint8
-    zero_point = torch.tensor(zero_point, dtype=zero_point_dtype)  # ONNX requires zero_point to be tensor
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
     return g.op("DequantizeLinear", g.op("QuantizeLinear", inputs, scale, zero_point), scale, zero_point)
 
 
diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py
index 61f9793..ed0439e 100644
--- a/torch/onnx/symbolic_opset13.py
+++ b/torch/onnx/symbolic_opset13.py
@@ -131,20 +131,36 @@
 
 @parse_args("v", "v", "v", "i", "i", "i")
 def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
-    if quant_min not in [0, -128] or quant_max not in [127, 255]:
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
         raise RuntimeError(
-            "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
-
+            "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+            "Got ({}, {})".format(quant_min, quant_max))
     # ONNX defines zero_point to be int8 or uint8
     if quant_min == 0:
-        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx["Byte"])
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
     else:
-        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx["Char"])
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
     return g.op(
         "DequantizeLinear",
         g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
         scale, zero_point, axis_i=axis)
 
+
+@parse_args("v", "v", "v", "i", "i")
+def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
+        raise RuntimeError(
+            "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+            "Got ({}, {})".format(quant_min, quant_max))
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
+    if scale.type().scalarType() != "Float":
+        scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+    return g.op("DequantizeLinear", g.op("QuantizeLinear", inputs, scale, zero_point), scale, zero_point)
+
+
 def _reduce_op_symbolic(onnx_op_name):
     def symbolic(g, self, dim=None, keepdim=None):
         self = _maybe_cast_reduce_op_input(g, self)