[ONNX] refactor test_pytorch_onnx_onnxruntime_cuda.py (#84218)
Fix #80037
After https://github.com/pytorch/pytorch/pull/79641, the code was outdated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84218
Approved by: https://github.com/BowenBao
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
index 3832a11..193b87a 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
@@ -2,7 +2,10 @@
import unittest
+import onnx_test_common
+
import onnxruntime # noqa: F401
+import parameterized
import torch
from pytorch_test_common import (
@@ -11,18 +14,22 @@
skipIfUnsupportedMinOpsetVersion,
skipScriptTest,
)
-from test_pytorch_onnx_onnxruntime import TestONNXRuntime
+from test_pytorch_onnx_onnxruntime import (
+ _parameterized_class_attrs_and_values,
+ MAX_ONNX_OPSET_VERSION,
+ MIN_ONNX_OPSET_VERSION,
+)
from torch.cuda.amp import autocast
-from torch.onnx._globals import GLOBALS
from torch.testing._internal import common_utils
-class TestONNXRuntime_cuda(common_utils.TestCase):
-
- opset_version = GLOBALS.export_onnx_opset_version
- keep_initializers_as_inputs = True
- onnx_shape_inference = True
-
+@parameterized.parameterized_class(
+ **_parameterized_class_attrs_and_values(
+ MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION
+ ),
+ class_name_func=onnx_test_common.parameterize_class_name,
+)
+class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
@skipIfUnsupportedMinOpsetVersion(9)
@skipIfNoCuda
def test_gelu_fp16(self):
@@ -145,8 +152,5 @@
self.run_test(Model(), (x, y))
-TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp
-TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test
-
if __name__ == "__main__":
common_utils.run_tests()