[ONNX] add supplement for standardOps low precision cast (#60731) (#61561)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61561
Address Gary's reply and add a supplement to https://github.com/pytorch/pytorch/pull/53813.
- Add more details for LowPrecisionCastNodeForStandardOps to make it more comprehensible (see the sketch below).
- Remove the unused gemm test.
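
For reference, a minimal sketch (not part of this change) of the pattern the pass handles:
a standard op (add) on uint8 tensors, exported with an opset version below 14, where ONNX
does not accept low-precision integer inputs for these ops. The module name, shapes, and
opset below are illustrative assumptions only.

    import io
    import torch

    class Uint8Add(torch.nn.Module):
        def forward(self, x, y):
            # add on low-precision integer tensors, as in the transfo_xl mask code
            return x + y

    x = torch.ones(3, 4, dtype=torch.uint8)
    y = torch.ones(3, 4, dtype=torch.uint8)

    # The scalar type analysis pass inserts casts so the ONNX op runs on a
    # supported wider type and the result is cast back to uint8.
    torch.onnx.export(Uint8Add(), (x, y), io.BytesIO(), opset_version=13)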
Test Plan: Imported from OSS
Reviewed By: nikithamalgifb
Differential Revision: D29767991
Pulled By: SplitInfinity
fbshipit-source-id: d00032e13699f5b02fc619e64aa8fdd39f3a66b8
Co-authored-by: hwangdeyu <dejack953@outlook.com>
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 6c7de82..00e1568 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2836,33 +2836,6 @@
y = torch.tensor([2, 3, 5], dtype=torch.float64)
self.run_test(ModModule(), (x, y))
- @unittest.skip("Gemm operator only support float/double in ONNX")
- @skipIfUnsupportedMaxOpsetVersion(13)
- def test_gemm_with_low_precision(self):
- class GemmModule(torch.nn.Module):
- def forward(self, x, y):
- return torch.mm(x, y).to(dtype=torch.long)
-
- mat1 = torch.randn(2, 3).to(dtype=torch.uint8)
- mat2 = torch.randn(3, 2).to(dtype=torch.uint8)
- self.run_test(GemmModule(), (mat1, mat2))
-
- mat1 = torch.randn(2, 3).to(dtype=torch.int8)
- mat2 = torch.randn(3, 2).to(dtype=torch.int8)
- self.run_test(GemmModule(), (mat1, mat2))
-
- mat1 = torch.randn(2, 3).to(dtype=torch.int16)
- mat2 = torch.randn(3, 2).to(dtype=torch.int16)
- self.run_test(GemmModule(), (mat1, mat2))
-
- mat1 = torch.randn(2, 3).to(dtype=torch.uint8)
- mat2 = torch.randn(3, 2).to(dtype=torch.int32)
- self.run_test(GemmModule(), (mat1, mat2))
-
- mat1 = torch.randn(2, 3).to(dtype=torch.uint8)
- mat2 = torch.randn(3, 2).to(dtype=torch.float64)
- self.run_test(GemmModule(), (mat1, mat2))
-
def test_std(self):
class StandardDeviation(torch.nn.Module):
def forward(self, input):
diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
index a6a9eb5..187f64c 100644
--- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
+++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
@@ -227,7 +227,10 @@
static c10::optional<c10::ScalarType> LowPrecisionCastForStandardOps(
const Node* n,
const c10::ScalarType& scalar_type) {
- // StandardOps do not support uint8\int8\int16 in ONNX now.
+ // Some standard ops do not support uint8/int8/int16 types in ONNX for
+ // opset versions < 14.
+ // Fixed by this ONNX PR:
+ // https://github.com/onnx/onnx/pull/3334
if (n->kind() != onnx::Gemm && IsStandardOp(n->kind()) &&
(scalar_type == c10::kByte || scalar_type == c10::kChar ||
scalar_type == c10::kShort)) {
@@ -300,6 +303,21 @@
out->replaceAllUsesAfterNodeWith(cast_node, cast_node->output());
}
+ // An example of this error was found when exporting the transfo_xl model,
+ // which uses the add op on uint8 tensors, as below:
+ //   if self.same_length:
+ //       all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
+ //       mask_len = klen - self.mem_len
+ //       if mask_len > 0:
+ //           mask_shift_len = qlen - mask_len
+ //       else:
+ //           mask_shift_len = qlen
+ //       dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones,
+ //           -mask_shift_len))[:, :, None]  # -1
+ //
+ // `all_ones` is a uint8 tensor, but the calculation of `dec_attn_mask` uses
+ // the add (+) op to produce a uint8 result. Reference link:
+ // https://github.com/huggingface/transformers/blob/b020a736c374460af1b34267283f957988350630/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L936
static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) {
TORCH_INTERNAL_ASSERT(n->outputs().size() == 1);
if (n->output()->type()->cast<TensorType>() == nullptr ||