[ONNX] add supplement for standardOps low precision cast (#60731) (#61561)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61561

address Gary's review reply and add a supplement to https://github.com/pytorch/pytorch/pull/53813.

- add more detail to the comments around LowPrecisionCastNodeForStandardOps to make it more comprehensible (a minimal repro sketch follows after this message).

- remove unused gemm test

Test Plan: Imported from OSS

Reviewed By: nikithamalgifb

Differential Revision: D29767991

Pulled By: SplitInfinity

fbshipit-source-id: d00032e13699f5b02fc619e64aa8fdd39f3a66b8

Co-authored-by: hwangdeyu <dejack953@outlook.com>
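
For context, a minimal repro sketch (not part of this diff; the module name, shapes, and dtypes are made up for illustration) of the kind of low precision op the cast pass handles when targeting opset < 14:

```python
import io
import torch

class AddModule(torch.nn.Module):
    def forward(self, x, y):
        # ONNX Add does not accept uint8 before opset 14, so the exporter's
        # scalar type analysis is expected to wrap this add in Cast nodes.
        return x + y

x = torch.ones(3, 4, dtype=torch.uint8)
y = torch.ones(3, 4, dtype=torch.uint8)
f = io.BytesIO()
torch.onnx.export(AddModule(), (x, y), f, opset_version=13)
```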
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 6c7de82..00e1568 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2836,33 +2836,6 @@
         y = torch.tensor([2, 3, 5], dtype=torch.float64)
         self.run_test(ModModule(), (x, y))
 
-    @unittest.skip("Gemm operator only support float/double in ONNX")
-    @skipIfUnsupportedMaxOpsetVersion(13)
-    def test_gemm_with_low_precision(self):
-        class GemmModule(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.mm(x, y).to(dtype=torch.long)
-
-            mat1 = torch.randn(2, 3).to(dtype=torch.uint8)
-            mat2 = torch.randn(3, 2).to(dtype=torch.uint8)
-            self.run_test(GemmModule(), (mat1, mat2))
-
-            mat1 = torch.randn(2, 3).to(dtype=torch.int8)
-            mat2 = torch.randn(3, 2).to(dtype=torch.int8)
-            self.run_test(GemmModule(), (mat1, mat2))
-
-            mat1 = torch.randn(2, 3).to(dtype=torch.int16)
-            mat2 = torch.randn(3, 2).to(dtype=torch.int16)
-            self.run_test(GemmModule(), (mat1, mat2))
-
-            mat1 = torch.randn(2, 3).to(dtype=torch.uint8)
-            mat2 = torch.randn(3, 2).to(dtype=torch.int32)
-            self.run_test(GemmModule(), (mat1, mat2))
-
-            mat1 = torch.randn(2, 3).to(dtype=torch.uint8)
-            mat2 = torch.randn(3, 2).to(dtype=torch.float64)
-            self.run_test(GemmModule(), (mat1, mat2))
-
     def test_std(self):
         class StandardDeviation(torch.nn.Module):
             def forward(self, input):
diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
index a6a9eb5..187f64c 100644
--- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
+++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
@@ -227,7 +227,10 @@
 static c10::optional<c10::ScalarType> LowPrecisionCastForStandardOps(
     const Node* n,
     const c10::ScalarType& scalar_type) {
-  // StandardOps do not support uint8\int8\int16 in ONNX now.
+  // Some standardOps do not support uint8/int8/int16 types for ONNX
+  // opset versions < 14.
+  // Fixed in this ONNX PR:
+  // https://github.com/onnx/onnx/pull/3334
   if (n->kind() != onnx::Gemm && IsStandardOp(n->kind()) &&
       (scalar_type == c10::kByte || scalar_type == c10::kChar ||
        scalar_type == c10::kShort)) {
@@ -300,6 +303,21 @@
   out->replaceAllUsesAfterNodeWith(cast_node, cast_node->output());
 }
 
+// This example error was found when exporting the transfo_xl model, which
+// uses the add op on uint8 tensors, as below:
+// if self.same_length:
+//     all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
+//     mask_len = klen - self.mem_len
+//     if mask_len > 0:
+//         mask_shift_len = qlen - mask_len
+//     else:
+//         mask_shift_len = qlen
+//     dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones,
+//     -mask_shift_len))[:, :, None]  # -1
+//
+// `all_ones` is a uint8 tensor, but the calculation of `dec_attn_mask` uses
+// the add(+) op and expects a uint8 result. Reference link:
+// https://github.com/huggingface/transformers/blob/b020a736c374460af1b34267283f957988350630/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L936
 static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) {
   TORCH_INTERNAL_ASSERT(n->outputs().size() == 1);
   if (n->output()->type()->cast<TensorType>() == nullptr ||
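
To see the effect of the pass on a transfo_xl-style mask computation, a sketch (assuming the `onnx` package is installed; `MaskModule` and its shapes are hypothetical, simplified from the quoted dec_attn_mask code) that inspects the exported graph for the inserted Cast nodes:

```python
import io
import onnx
import torch

class MaskModule(torch.nn.Module):
    def forward(self, all_ones):
        # mirrors the transfo_xl dec_attn_mask pattern: two uint8 tensors
        # combined with add(+), then unsqueezed
        return (all_ones + all_ones)[:, :, None]

all_ones = torch.ones(4, 6, dtype=torch.uint8)
f = io.BytesIO()
torch.onnx.export(MaskModule(), (all_ones,), f, opset_version=13)

model = onnx.load_from_string(f.getvalue())
# With the low precision cast in place, the Add should be fed by Cast nodes
# rather than receiving uint8 inputs directly.
print([node.op_type for node in model.graph.node])
```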