[ONNX] Add dtype check in onnx verification (#79263)

Currently we don't have a dtype check when verifying the consistency between PyTorch and ONNX outputs. As a result, several dtype inconsistencies were found and reported: #77842 #77845

This PR is a proof of concept (POC); a usage sketch of the new `check_shape`/`check_dtype` options follows the failure list below.

Failed workflows:
- linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)
  - inconsistent shape
    - TestONNXRuntime_opset10.test_all (#79371)
    - TestONNXRuntime_opset10.test_any (#79371)
    - TestONNXRuntime_opset10.test_argmin_argmax (#79503)
    - TestONNXRuntime_opset10.test_hardshrink (#79695)
    - TestONNXRuntime_opset10.test_linalg_norm (#79506)
    - TestONNXRuntime_opset10.test_linalg_vector_norm (#79506)
    - TestONNXRuntime_opset10.test_prelu_scalar (#79846)
    - TestONNXRuntime_opset10.test_softshrink (#79695)
    - TestONNXRuntime_opset10.test_sum_empty_tensor (skipped)
    - TestONNXRuntime_opset10.test_tolist (skipped)
  - inconsistent dtype
    - test_arithmetic_prim_bool (skipped)
    - test_arithmeticOps_with_low_precision (skipped)
    - test_arithmetic_prim_float (skipped)
    - test_logical_and (#79339)
    - test_logical_or (#79339)
    - test_logical_xor (#79339)
    - test_pow (skipped)
    - test_primitive_input_floating (skipped)
    - test_quantize_per_tensor (#79690)
    - test_quantized_adaptive_avg_pool2d (#79690)
    - test_quantized_arithmetic (#79690)
    - test_quantized_arithmetic_qfunctional (#79690)
    - test_quantized_conv2d (#79690)
    - test_quantized_conv2d_relu (#79690)
    - test_quantized_flatten (#79690)
    - test_quantized_hardsigmoid (#79690)
    - test_quantized_hardswish (#79690)
    - test_quantized_linear (#79690)
    - test_quantized_sigmoid (#79690)
    - test_item (skipped)
    - test_full_like_value (skipped)
    - TestONNXRuntime_opset7.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset8.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9_IRv4.test_div_rounding_mode (skipped)
    - test_outer (skipped)
    - test_symbolic_shape_inference_arange_2 (skipped)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79263
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
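
For illustration, a minimal sketch of driving the new checks through `torch.onnx.verification.verify`. The model and inputs here are hypothetical, and the leading positional parameters are not shown in the hunks below, so their exact placement is an assumption:

```python
import torch
from torch.onnx import verification


class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1


# check_shape/check_dtype are the options added by this PR; both default
# to True, so strict checking is the new out-of-the-box behavior.
verification.verify(
    Model(),
    (torch.randn(2, 3),),
    check_shape=True,
    check_dtype=True,
)
```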
diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index 9c617c3..fe0614f 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -37,6 +37,10 @@
     kwargs["ort_providers"] = _ORT_PROVIDERS
     kwargs["opset_version"] = test_suite.opset_version
     kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs
+    if hasattr(test_suite, "check_shape"):
+        kwargs["check_shape"] = test_suite.check_shape
+    if hasattr(test_suite, "check_dtype"):
+        kwargs["check_dtype"] = test_suite.check_dtype
     return verification.verify(*args, **kwargs)
 
 
@@ -60,6 +64,8 @@
     opset_version = _constants.onnx_default_opset
     keep_initializers_as_inputs = True  # For IR version 3 type export.
     is_script = False
+    check_shape = True
+    check_dtype = True
 
     def setUp(self):
         set_rng_seed(0)
diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py
index 77bdd28..67b3905 100644
--- a/test/onnx/pytorch_test_common.py
+++ b/test/onnx/pytorch_test_common.py
@@ -139,5 +139,23 @@
     return skip_dec
 
 
+def skipShapeChecking(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        self.check_shape = False
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+def skipDtypeChecking(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        self.check_dtype = False
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
 def flatten(x):
     return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
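
As a usage note, tests opt out of a check per method via the decorators above: each wrapper flips the flag on `self` before running the test body, and the `hasattr` checks added in onnx_test_common.py forward the relaxed setting to `verification.verify`. A hypothetical test using the decorator (the base class name is an assumption):

```python
import torch

import onnx_test_common
from pytorch_test_common import skipDtypeChecking


class TestONNXRuntime(onnx_test_common._TestONNXRuntime):  # base name is an assumption
    @skipDtypeChecking
    def test_scalar_output(self):
        class M(torch.nn.Module):
            def forward(self, x):
                # item() returns a Python scalar whose exported dtype can
                # legitimately differ, so the dtype check is relaxed here.
                return x.item() + x

        self.run_test(M(), (torch.tensor(1.0),))
```

Because unittest builds a fresh test-case instance per method, the flag flipped by the decorator affects only that one test; every other test keeps the strict defaults from the class attributes.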
diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py
index a850779..4794fb5 100644
--- a/test/onnx/test_pytorch_jit_onnx.py
+++ b/test/onnx/test_pytorch_jit_onnx.py
@@ -50,6 +50,8 @@
 
     opset_version = -1  # Sub-classes must override
     ort_providers = ["CPUExecutionProvider"]
+    check_shape = True
+    check_dtype = True
 
     def run_test(self, graph_ir, example_inputs):
         graph = torch._C.parse_ir(graph_ir)
@@ -64,7 +66,12 @@
         ort_outs = verification._run_ort(ort_sess, example_inputs)
 
         verification._compare_ort_pytorch_outputs(
-            ort_outs, jit_outs, rtol=1e-3, atol=1e-7
+            ort_outs,
+            jit_outs,
+            rtol=1e-3,
+            atol=1e-7,
+            check_shape=self.check_shape,
+            check_dtype=self.check_dtype,
         )
 
     def test_example_ir(self):
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 2905345..0159248 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -26,11 +26,13 @@
     RNN_HIDDEN_SIZE,
     RNN_INPUT_SIZE,
     RNN_SEQUENCE_LENGTH,
+    skipDtypeChecking,
     skipForAllOpsetVersions,
     skipIfUnsupportedMaxOpsetVersion,
     skipIfUnsupportedMinOpsetVersion,
     skipIfUnsupportedOpsetVersion,
     skipScriptTest,
+    skipShapeChecking,
     skipTraceTest,
 )
 from torch import Tensor
@@ -827,6 +829,7 @@
         y = torch.randint(10, (2, 3, 4))
         self.run_test(Model(), (x, y))
 
+    @skipDtypeChecking
     def test_primitive_input_floating(self):
         class Model(torch.nn.Module):
             def __init__(self):
@@ -1531,6 +1534,7 @@
         x = torch.randn(2, 3, 4)
         self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
 
+    @skipDtypeChecking
     def test_arithmetic_prim_float(self):
         class ArithmeticModule(torch.nn.Module):
             def forward(self, x, y: float):
@@ -1553,6 +1557,7 @@
         x = torch.randn(2, 3, 4)
         self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
 
+    @skipDtypeChecking
     def test_arithmetic_prim_bool(self):
         class ArithmeticModule(torch.nn.Module):
             def forward(self, x, y: int, z: bool, t: float):
@@ -1720,6 +1725,7 @@
         y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
         self.run_test(torch.jit.script(DivModule()), (x, y))
 
+    @skipDtypeChecking
     def test_div_rounding_mode(self):
         class TrueDivModule(torch.nn.Module):
             def forward(self, x, y):
@@ -2940,6 +2946,7 @@
             torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
         )
 
+    @skipDtypeChecking
     def test_pow(self):
         class PowModule(torch.nn.Module):
             def forward(self, x, y):
@@ -2986,6 +2993,7 @@
     # add to(dtype=torch.long) to avoid ORT output type does not match expected type.
     # will be fixed in ONNX version 14.
     @skipIfUnsupportedMaxOpsetVersion(13)
+    @skipDtypeChecking
     def test_arithmeticOps_with_low_precision(self):
         class AddModule(torch.nn.Module):
             def forward(self, x, y):
@@ -5279,6 +5287,7 @@
         ind = torch.tensor(-2, dtype=torch.long)
         self.run_test(GetItemModel(), (x, y, z, ind))
 
+    @skipDtypeChecking
     def test_item(self):
         class M(torch.nn.Module):
             def forward(self, x, y, i: int):
@@ -6085,6 +6094,7 @@
         self.run_test(ZeroAndOnes(), (x,))
 
     @skipIfUnsupportedMinOpsetVersion(9)
+    @skipShapeChecking
     def test_tolist(self):
         class List(torch.jit.ScriptModule):
             @torch.jit.script_method
@@ -6626,6 +6636,7 @@
         self.run_test(FullLikeModel(), x)
 
     @skipIfUnsupportedMinOpsetVersion(9)
+    @skipDtypeChecking
     def test_full_like_value(self):
         class FullLikeModel(torch.nn.Module):
             def forward(self, x, y):
@@ -7892,6 +7903,7 @@
         self.run_test(M(), (dummy_inputs,), input_names=["x"], dynamic_axes={"x": [0]})
 
     @skipIfUnsupportedMinOpsetVersion(12)
+    @skipDtypeChecking
     def test_outer(self):
         class Outer(torch.nn.Module):
             def forward(self, x, y):
@@ -11060,6 +11072,7 @@
         self.run_test(model, (boxes, scores))
 
     @skipIfUnsupportedMinOpsetVersion(11)
+    @skipDtypeChecking
     def test_symbolic_shape_inference_arange_2(self):
         # test Range
         class ArangeModel(torch.nn.Module):
@@ -11516,6 +11529,7 @@
         x = torch.ones(12, 3)
         self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0]})
 
+    @skipShapeChecking
     def test_sum_empty_tensor(self):
         class M(torch.nn.Module):
             def forward(self, x):
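
As an illustrative reproduction (not taken from the test file) of the kind of mismatch that sends a test into `@skipDtypeChecking`: PyTorch keeps the integer dtype for trunc/floor rounding modes, while an export path that lowers the division through a floating-point Div would hand ONNX Runtime a float output. The ONNX-side behavior in the comment is an assumption about why `test_div_rounding_mode` trips the check:

```python
import torch

x = torch.tensor([7, -7], dtype=torch.int64)

# PyTorch: integer in, integer out for floor/trunc rounding modes.
print(torch.div(x, 2, rounding_mode="floor").dtype)  # torch.int64

# If the exported ONNX graph computes this as a float Div + Floor, ONNX
# Runtime returns float32, and the new check_dtype comparison fails.
```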
diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py
index 8af3e80a..5a155b6 100644
--- a/torch/onnx/verification.py
+++ b/torch/onnx/verification.py
@@ -115,14 +115,34 @@
     return ort_session
 
 
-def _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol):
+def _compare_ort_pytorch_outputs(
+    ort_outs: Sequence[np.ndarray],
+    pt_outs: Sequence[torch.Tensor],
+    rtol: float,
+    atol: float,
+    check_shape: bool,
+    check_dtype: bool,
+):
     pt_outs, _ = torch.jit._flatten(pt_outs)
     pt_outs = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
 
-    assert len(pt_outs) == len(ort_outs), "number of outputs differ"
+    assert len(ort_outs) == len(
+        pt_outs
+    ), f"Number of outputs differ: ONNX Runtime ({len(ort_outs)}) vs PyTorch ({len(pt_outs)})"
 
     for ort_out, pt_out in zip(ort_outs, pt_outs):
-        np.testing.assert_allclose(ort_out, pt_out, rtol=rtol, atol=atol)
+        # TODO: Remove the `check_shape` option once all shape inconsistency issues are addressed.
+        if not check_shape:
+            # Allow different but broadcastable output shapes.
+            ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
+        torch.testing.assert_close(
+            ort_out,
+            pt_out,
+            rtol=rtol,
+            atol=atol,
+            check_dtype=check_dtype,
+            equal_nan=True,
+        )
 
 
 def _prepare_input_for_pytorch(args, kwargs):
@@ -221,6 +241,8 @@
     flatten,
     rtol,
     atol,
+    check_shape,
+    check_dtype,
 ):
     """Compare outputs from ONNX model runs with outputs from PyTorch model runs.
 
@@ -242,7 +264,9 @@
         )
         ort_outs = _run_ort(ort_session, ort_inputs)
 
-        _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol)
+        _compare_ort_pytorch_outputs(
+            ort_outs, pt_outs, rtol, atol, check_shape, check_dtype
+        )
 
     compare_ort_pytorch_model_with_input(input_args, input_kwargs)
 
@@ -519,6 +543,8 @@
     additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
     remained_onnx_input_idx: Optional[Sequence[int]] = None,
     flatten: bool = True,
+    check_shape: bool = True,
+    check_dtype: bool = True,
     ort_providers: Sequence[str] = _ORT_PROVIDERS,
     rtol: float = 0.001,
     atol: float = 1e-7,
@@ -552,6 +578,11 @@
             inputs into a flattened list of Tensors for ONNX. Set this to False if nested
             structures are to be preserved for ONNX, which is usually the case with
             exporting ScriptModules.
+        check_shape (bool, optional): Default True. If True, check that the shapes
+            of the PyTorch and ONNX Runtime outputs are exactly the same. Set this
+            to False to allow output shape broadcasting.
+        check_dtype (bool, optional): Default True. If True, check that the dtypes
+            of the PyTorch and ONNX Runtime outputs are consistent.
         ort_providers (sequence, optional): ONNX Runtime providers to use.
         rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
         atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
@@ -601,4 +632,6 @@
             flatten,
             rtol,
             atol,
+            check_shape,
+            check_dtype,
         )