[ONNX] Add dtype check in onnx verification (#79263)
Currently we don't have a dtype check when verifying the consistency between PyTorch and ONNX outputs. As a result, several dtype inconsistencies were found and reported: #77842, #77845.
This is a proof of concept (POC).
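The new behavior is exposed through two keyword arguments on `torch.onnx.verification.verify`, `check_shape` and `check_dtype` (both default to `True`; see the `torch/onnx/verification.py` hunk in the diff below). A minimal sketch of relaxing the dtype comparison, assuming the usual `(model, input_args)` calling convention used by the test harness; the module and inputs here are purely illustrative:

```python
import torch
from torch.onnx import verification


class AddOne(torch.nn.Module):
    def forward(self, x):
        # Mixing an integer tensor with a Python int is the kind of pattern
        # where the ONNX Runtime output dtype may differ from eager PyTorch,
        # which the new dtype check would now flag.
        return x + 1


# check_dtype=False relaxes only the new dtype comparison; check_shape=False
# would additionally allow broadcastable (but not identical) output shapes.
verification.verify(
    AddOne(),
    (torch.ones(2, 3, dtype=torch.int32),),
    check_dtype=False,
    check_shape=True,
)
```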
Failed workflows (tests marked "(skipped)" below disable only the corresponding new check, via the decorators sketched after this list):
- [linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)]
  - inconsistent shape
    - TestONNXRuntime_opset10.test_all (#79371)
    - TestONNXRuntime_opset10.test_any (#79371)
    - TestONNXRuntime_opset10.test_argmin_argmax (#79503)
    - TestONNXRuntime_opset10.test_hardshrink (#79695)
    - TestONNXRuntime_opset10.test_linalg_norm (#79506)
    - TestONNXRuntime_opset10.test_linalg_vector_norm (#79506)
    - TestONNXRuntime_opset10.test_prelu_scalar (#79846)
    - TestONNXRuntime_opset10.test_softshrink (#79695)
    - TestONNXRuntime_opset10.test_sum_empty_tensor (skipped)
    - TestONNXRuntime_opset10.test_tolist (skipped)
  - inconsistent dtype
    - test_arithmetic_prim_bool (skipped)
    - test_arithmeticOps_with_low_precision (skipped)
    - test_arithmetic_prim_float (skipped)
    - test_logical_and (#79339)
    - test_logical_or (#79339)
    - test_logical_xor (#79339)
    - test_pow (skipped)
    - test_primitive_input_floating (skipped)
    - test_quantize_per_tensor (#79690)
    - test_quantized_adaptive_avg_pool2d (#79690)
    - test_quantized_arithmetic (#79690)
    - test_quantized_arithmetic_qfunctional (#79690)
    - test_quantized_conv2d (#79690)
    - test_quantized_conv2d_relu (#79690)
    - test_quantized_flatten (#79690)
    - test_quantized_hardsigmoid (#79690)
    - test_quantized_hardswish (#79690)
    - test_quantized_linear (#79690)
    - test_quantized_sigmoid (#79690)
    - test_item (skipped)
    - test_full_like_value (skipped)
    - TestONNXRuntime_opset7.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset8.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9_IRv4.test_div_rounding_mode (skipped)
    - test_outer (skipped)
    - test_symbolic_shape_inference_arange_2 (skipped)
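The tests marked "(skipped)" above still run; they only opt out of the corresponding check through the new `skipShapeChecking` / `skipDtypeChecking` decorators added to `test/onnx/pytorch_test_common.py`, which set `check_shape` / `check_dtype` to `False` on the test instance before the test body executes. A minimal sketch of the pattern with an illustrative test class and bodies (only the decorator names and the `run_test` call style come from this PR):

```python
import torch
from pytorch_test_common import skipDtypeChecking, skipShapeChecking


class OnnxRuntimeTestsSketch(_TestONNXRuntimeBase):  # hypothetical base class name
    @skipDtypeChecking  # sets self.check_dtype = False for this test only
    def test_pow(self):
        class PowModule(torch.nn.Module):
            def forward(self, x, y):
                return x.pow(y)

        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(PowModule(), (x, y))

    @skipShapeChecking  # sets self.check_shape = False for this test only
    def test_sum_empty_tensor(self):
        class SumModule(torch.nn.Module):
            def forward(self, x):
                return x.sum()

        x = torch.zeros(0, 3)
        self.run_test(SumModule(), (x,))
```

Because unittest creates a fresh test instance for each test method, flipping the flag on `self` does not leak into other tests in the class.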
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79263
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index 9c617c3..fe0614f 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -37,6 +37,10 @@
kwargs["ort_providers"] = _ORT_PROVIDERS
kwargs["opset_version"] = test_suite.opset_version
kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs
+ if hasattr(test_suite, "check_shape"):
+ kwargs["check_shape"] = test_suite.check_shape
+ if hasattr(test_suite, "check_dtype"):
+ kwargs["check_dtype"] = test_suite.check_dtype
return verification.verify(*args, **kwargs)
@@ -60,6 +64,8 @@
opset_version = _constants.onnx_default_opset
keep_initializers_as_inputs = True # For IR version 3 type export.
is_script = False
+ check_shape = True
+ check_dtype = True
def setUp(self):
set_rng_seed(0)
diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py
index 77bdd28..67b3905 100644
--- a/test/onnx/pytorch_test_common.py
+++ b/test/onnx/pytorch_test_common.py
@@ -139,5 +139,23 @@
return skip_dec
+def skipShapeChecking(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ self.check_shape = False
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
+def skipDtypeChecking(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ self.check_dtype = False
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
def flatten(x):
return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py
index a850779..4794fb5 100644
--- a/test/onnx/test_pytorch_jit_onnx.py
+++ b/test/onnx/test_pytorch_jit_onnx.py
@@ -50,6 +50,8 @@
opset_version = -1 # Sub-classes must override
ort_providers = ["CPUExecutionProvider"]
+ check_shape = True
+ check_dtype = True
def run_test(self, graph_ir, example_inputs):
graph = torch._C.parse_ir(graph_ir)
@@ -64,7 +66,12 @@
ort_outs = verification._run_ort(ort_sess, example_inputs)
verification._compare_ort_pytorch_outputs(
- ort_outs, jit_outs, rtol=1e-3, atol=1e-7
+ ort_outs,
+ jit_outs,
+ rtol=1e-3,
+ atol=1e-7,
+ check_shape=self.check_shape,
+ check_dtype=self.check_dtype,
)
def test_example_ir(self):
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 2905345..0159248 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -26,11 +26,13 @@
RNN_HIDDEN_SIZE,
RNN_INPUT_SIZE,
RNN_SEQUENCE_LENGTH,
+ skipDtypeChecking,
skipForAllOpsetVersions,
skipIfUnsupportedMaxOpsetVersion,
skipIfUnsupportedMinOpsetVersion,
skipIfUnsupportedOpsetVersion,
skipScriptTest,
+ skipShapeChecking,
skipTraceTest,
)
from torch import Tensor
@@ -827,6 +829,7 @@
y = torch.randint(10, (2, 3, 4))
self.run_test(Model(), (x, y))
+ @skipDtypeChecking
def test_primitive_input_floating(self):
class Model(torch.nn.Module):
def __init__(self):
@@ -1531,6 +1534,7 @@
x = torch.randn(2, 3, 4)
self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
+ @skipDtypeChecking
def test_arithmetic_prim_float(self):
class ArithmeticModule(torch.nn.Module):
def forward(self, x, y: float):
@@ -1553,6 +1557,7 @@
x = torch.randn(2, 3, 4)
self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
+ @skipDtypeChecking
def test_arithmetic_prim_bool(self):
class ArithmeticModule(torch.nn.Module):
def forward(self, x, y: int, z: bool, t: float):
@@ -1720,6 +1725,7 @@
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
self.run_test(torch.jit.script(DivModule()), (x, y))
+ @skipDtypeChecking
def test_div_rounding_mode(self):
class TrueDivModule(torch.nn.Module):
def forward(self, x, y):
@@ -2940,6 +2946,7 @@
torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
)
+ @skipDtypeChecking
def test_pow(self):
class PowModule(torch.nn.Module):
def forward(self, x, y):
@@ -2986,6 +2993,7 @@
# add to(dtype=torch.long) to avoid ORT output type does not match expected type.
# will be fixed in ONNX version 14.
@skipIfUnsupportedMaxOpsetVersion(13)
+ @skipDtypeChecking
def test_arithmeticOps_with_low_precision(self):
class AddModule(torch.nn.Module):
def forward(self, x, y):
@@ -5279,6 +5287,7 @@
ind = torch.tensor(-2, dtype=torch.long)
self.run_test(GetItemModel(), (x, y, z, ind))
+ @skipDtypeChecking
def test_item(self):
class M(torch.nn.Module):
def forward(self, x, y, i: int):
@@ -6085,6 +6094,7 @@
self.run_test(ZeroAndOnes(), (x,))
@skipIfUnsupportedMinOpsetVersion(9)
+ @skipShapeChecking
def test_tolist(self):
class List(torch.jit.ScriptModule):
@torch.jit.script_method
@@ -6626,6 +6636,7 @@
self.run_test(FullLikeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
+ @skipDtypeChecking
def test_full_like_value(self):
class FullLikeModel(torch.nn.Module):
def forward(self, x, y):
@@ -7892,6 +7903,7 @@
self.run_test(M(), (dummy_inputs,), input_names=["x"], dynamic_axes={"x": [0]})
@skipIfUnsupportedMinOpsetVersion(12)
+ @skipDtypeChecking
def test_outer(self):
class Outer(torch.nn.Module):
def forward(self, x, y):
@@ -11060,6 +11072,7 @@
self.run_test(model, (boxes, scores))
@skipIfUnsupportedMinOpsetVersion(11)
+ @skipDtypeChecking
def test_symbolic_shape_inference_arange_2(self):
# test Range
class ArangeModel(torch.nn.Module):
@@ -11516,6 +11529,7 @@
x = torch.ones(12, 3)
self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0]})
+ @skipShapeChecking
def test_sum_empty_tensor(self):
class M(torch.nn.Module):
def forward(self, x):
diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py
index 8af3e80a..5a155b6 100644
--- a/torch/onnx/verification.py
+++ b/torch/onnx/verification.py
@@ -115,14 +115,34 @@
return ort_session
-def _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol):
+def _compare_ort_pytorch_outputs(
+ ort_outs: Sequence[np.ndarray],
+ pt_outs: Sequence[torch.Tensor],
+ rtol: float,
+ atol: float,
+ check_shape: bool,
+ check_dtype: bool,
+):
pt_outs, _ = torch.jit._flatten(pt_outs)
pt_outs = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
- assert len(pt_outs) == len(ort_outs), "number of outputs differ"
+ assert len(ort_outs) == len(
+ pt_outs
+ ), f"Number of outputs differ ONNX runtime: ({len(ort_outs)}) PyTorch: ({len(pt_outs)})"
for ort_out, pt_out in zip(ort_outs, pt_outs):
- np.testing.assert_allclose(ort_out, pt_out, rtol=rtol, atol=atol)
+ # TODO: Remove the `check_shape` option once all shape-inconsistency issues are addressed.
+ if not check_shape:
+ # Allow different but broadcastable output shapes.
+ ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
+ torch.testing.assert_close(
+ ort_out,
+ pt_out,
+ rtol=rtol,
+ atol=atol,
+ check_dtype=check_dtype,
+ equal_nan=True,
+ )
def _prepare_input_for_pytorch(args, kwargs):
@@ -221,6 +241,8 @@
flatten,
rtol,
atol,
+ check_shape,
+ check_dtype,
):
"""Compare outputs from ONNX model runs with outputs from PyTorch model runs.
@@ -242,7 +264,9 @@
)
ort_outs = _run_ort(ort_session, ort_inputs)
- _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol)
+ _compare_ort_pytorch_outputs(
+ ort_outs, pt_outs, rtol, atol, check_shape, check_dtype
+ )
compare_ort_pytorch_model_with_input(input_args, input_kwargs)
@@ -519,6 +543,8 @@
additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
remained_onnx_input_idx: Optional[Sequence[int]] = None,
flatten: bool = True,
+ check_shape: bool = True,
+ check_dtype: bool = True,
ort_providers: Sequence[str] = _ORT_PROVIDERS,
rtol: float = 0.001,
atol: float = 1e-7,
@@ -552,6 +578,11 @@
inputs into a flattened list of Tensors for ONNX. Set this to False if nested
structures are to be preserved for ONNX, which is usually the case with
exporting ScriptModules.
+ check_shape (bool, optional): Default True. If True, check that the shapes of the
+ PyTorch and ONNX Runtime outputs are exactly the same. Set this to False to allow
+ output shape broadcasting.
+ check_dtype (bool, optional): Default True. If True, check that the dtypes of the
+ PyTorch and ONNX Runtime outputs are consistent.
ort_providers (sequence, optional): ONNX Runtime providers to use.
rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
@@ -601,4 +632,6 @@
flatten,
rtol,
atol,
+ check_shape,
+ check_dtype,
)