| # Owner(s): ["module: onnx"] |
| |
| from __future__ import annotations |
| |
| import os |
| from typing import Any, Mapping, Type |
| |
| import onnxruntime |
| import pytorch_test_common |
| |
| import torch |
| from torch.onnx import _constants, verification |
| |
# Directory holding the ONNX backend test data, resolved relative to this
# file: <dir of this file>/../repos/onnx/onnx/backend/test/data.
# The ".." component is kept verbatim in the resulting path (it is joined,
# not normalised).
onnx_model_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    os.pardir,
    *("repos", "onnx", "onnx", "backend", "test", "data"),
)

# Sub-directories of the data root, one per family of PyTorch-related cases.
pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")

# Execution providers handed to verification as `ort_providers`.
_ORT_PROVIDERS = ("CPUExecutionProvider",)
| |
| |
def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
    """Call `verification.verify` with the suite-level settings injected.

    The following keyword arguments are always overwritten, even when the
    caller supplied them: `ort_providers` (from `_ORT_PROVIDERS`),
    `opset_version` and `keep_initializers_as_inputs` (both read from
    `test_suite`). `check_shape` and `check_dtype` are overwritten only when
    `test_suite` defines an attribute of that name.

    Args:
        test_suite: Object providing the settings listed above.
        *args: Positional arguments forwarded unchanged to `verify`.
        **kwargs: Keyword arguments forwarded to `verify` after the overrides.

    Returns:
        Whatever `verification.verify` returns.
    """
    kwargs.update(
        ort_providers=_ORT_PROVIDERS,
        opset_version=test_suite.opset_version,
        keep_initializers_as_inputs=test_suite.keep_initializers_as_inputs,
    )
    # These flags are optional on the suite; forward them only when present.
    for optional_flag in ("check_shape", "check_dtype"):
        if hasattr(test_suite, optional_flag):
            kwargs[optional_flag] = getattr(test_suite, optional_flag)
    return verification.verify(*args, **kwargs)
| |
| |
def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Build a class name from the base class and its parameter values.

    Intended to be passed to `parameterized.parameterized_class` as the
    `class_name_func` argument.

    Args:
        cls: Class being parameterized; its `__name__` is the prefix.
        idx: Index of the parameter set. Unused, but part of the callback
            signature.
        input_dicts: Parameter names mapped to values for this variant.

    Returns:
        `<cls name>_<k1>_<v1>_<k2>_<v2>...`, pairs in mapping iteration
        order. An empty mapping yields `<cls name>_`.
    """
    pieces = [f"{key}_{value}" for key, value in input_dicts.items()]
    return cls.__name__ + "_" + "_".join(pieces)
| |
| |
class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
    """Base test case that verifies models through `run_model_test`.

    The class attributes below are per-suite settings: `run_model_test`
    reads `opset_version`, `keep_initializers_as_inputs`, `check_shape` and
    `check_dtype`; `run_test` reads `is_script`.
    """

    # Forwarded to verification as `opset_version`.
    opset_version = _constants.ONNX_DEFAULT_OPSET
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    # True: `run_test` verifies the scripted model; False: the model as given.
    is_script = False
    # Forwarded to verification as `check_shape` / `check_dtype`.
    check_shape = True
    check_dtype = True

    def setUp(self):
        """Seed the random number generators and enable script testing."""
        super().setUp()
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        # NOTE(review): presumably lets onnxruntime accept opset versions it
        # has not officially released yet -- confirm against the ORT docs.
        # The variable is set process-wide and is not restored afterwards.
        os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
        # Instance-level switch; `run_test` skips the scripting pass when a
        # test sets this to False.
        self.is_script_test_enabled = True

    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        """Verify `model` in scripting or tracing mode, depending on the suite.

        The exported ONNX model may have fewer inputs than the PyTorch model
        because of constant folding. This mostly happens in unit tests, which
        widely use torch.size or torch.shape, so the output depends only on
        the input shape and not on its value. `remained_onnx_input_idx`
        indicates which PyTorch model input indices remain in the ONNX model.

        Args:
            model: Model to verify. A `torch.jit.ScriptModule` or
                `torch.jit.ScriptFunction` is used as is for the scripting
                pass and never goes through the tracing pass.
            remained_onnx_input_idx: Either one value used for both passes,
                or a dict with the keys "scripting" and "tracing" giving a
                value per pass. Both keys are required when a dict is given.

        All other arguments are forwarded unchanged, by keyword, to
        `run_model_test`.
        """
        # Options that are identical for the scripting and tracing passes.
        shared_options = {
            "input_args": input_args,
            "input_kwargs": input_kwargs,
            "rtol": rtol,
            "atol": atol,
            "do_constant_folding": do_constant_folding,
            "dynamic_axes": dynamic_axes,
            "additional_test_inputs": additional_test_inputs,
            "input_names": input_names,
            "output_names": output_names,
            "fixed_batch_size": fixed_batch_size,
            "training": training,
            "verbose": verbose,
        }

        def _verify(candidate, kept_input_idx, flatten=True, ignore_none=True):
            return run_model_test(
                self,
                candidate,
                remained_onnx_input_idx=kept_input_idx,
                flatten=flatten,
                ignore_none=ignore_none,
                **shared_options,
            )

        if isinstance(remained_onnx_input_idx, dict):
            scripting_idx, tracing_idx = (
                remained_onnx_input_idx["scripting"],
                remained_onnx_input_idx["tracing"],
            )
        else:
            scripting_idx = tracing_idx = remained_onnx_input_idx

        already_scripted = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        # Scripting pass: only for script suites with script testing enabled.
        if self.is_script_test_enabled and self.is_script:
            if already_scripted:
                scripted = model
            else:
                scripted = torch.jit.script(model)
            _verify(scripted, scripting_idx, flatten=False, ignore_none=False)

        # Tracing pass: only for non-script suites and not-yet-scripted models.
        if not (already_scripted or self.is_script):
            _verify(model, tracing_idx)