# Owner(s): ["module: onnx"]

import functools
import os
import random
import sys
import unittest
from typing import Optional

import numpy as np

import torch
from torch.autograd import function
from torch.onnx._internal import diagnostics
from torch.testing._internal import common_utils

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(-1, pytorch_test_dir)

torch.set_default_tensor_type("torch.FloatTensor")

BATCH_SIZE = 2

RNN_BATCH_SIZE = 7
RNN_SEQUENCE_LENGTH = 11
RNN_INPUT_SIZE = 5
RNN_HIDDEN_SIZE = 3


def _skipper(condition, reason):
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if condition():
                raise unittest.SkipTest(reason)
            return f(*args, **kwargs)

        return wrapper

    return decorator


skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(), "CUDA is not available")

skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"), "Skip in Travis")

skipIfNoBFloat16Cuda = _skipper(
    lambda: not torch.cuda.is_bf16_supported(), "BFloat16 CUDA is not available"
)
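

# A minimal usage sketch for the condition-based skips above (hypothetical
# test class and method names):
#
#     class TestFoo(ExportTestCase):
#         @skipIfNoCuda
#         def test_cuda_only_op(self):
#             ...  # body runs only when torch.cuda.is_available()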


# Skips tests for all opset versions below min_opset_version.
# If exporting the op is only supported after a specific opset version, add
# this decorator to prevent running the test for opset versions smaller than
# min_opset_version.
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} < {min_opset_version}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


# Skips tests for all opset versions above max_opset_version.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version > max_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} > {max_opset_version}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
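

# A minimal usage sketch for the two opset-version gates above (hypothetical
# test class; assumes the test harness sets `self.opset_version`):
#
#     class TestOps(ExportTestCase):
#         opset_version = 15
#
#         @skipIfUnsupportedMinOpsetVersion(16)
#         def test_op_added_in_16(self):
#             ...  # skipped: 15 < 16
#
#         @skipIfUnsupportedMaxOpsetVersion(14)
#         def test_op_removed_after_14(self):
#             ...  # skipped: 15 > 14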


# Skips tests for all opset versions.
def skipForAllOpsetVersions():
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version:
                raise unittest.SkipTest(
                    "Skip verify test for all opset versions"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skipTraceTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the tracing test for opset versions below skip_before_opset_version.

    Args:
        skip_before_opset_version: The opset version before which the tracing
            test is skipped. If None, the tracing test is always skipped.
        reason: The reason for skipping the tracing test.

    Returns:
        A decorator that skips the tracing test.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if skip_before_opset_version is not None:
                self.skip_this_opset = self.opset_version < skip_before_opset_version
            else:
                self.skip_this_opset = True
            if self.skip_this_opset and not self.is_script:
                raise unittest.SkipTest(f"Skip verify test for torch trace. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skipScriptTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the scripting test for opset versions below skip_before_opset_version.

    Args:
        skip_before_opset_version: The opset version before which the scripting
            test is skipped. If None, the scripting test is always skipped.
        reason: The reason for skipping the scripting test.

    Returns:
        A decorator that skips the scripting test.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if skip_before_opset_version is not None:
                self.skip_this_opset = self.opset_version < skip_before_opset_version
            else:
                self.skip_this_opset = True
            if self.skip_this_opset and self.is_script:
                raise unittest.SkipTest(f"Skip verify test for TorchScript. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
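

# A minimal usage sketch for the trace/script skips above (hypothetical test
# method; assumes the harness sets `self.opset_version` and `self.is_script`):
#
#     @skipScriptTest(skip_before_opset_version=11, reason="prim::Loop")
#     def test_loop_op(self):
#         # Skipped when scripting with opset_version < 11; tracing runs are
#         # unaffected. skipTraceTest mirrors this for the tracing path.
#         ...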


# Skips tests for the opset versions listed in unsupported_opset_versions.
# If the caffe2 test cannot be run for a specific opset version, add this
# wrapper (for example, when an op was modified but the change is not
# supported in caffe2).
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version in unsupported_opset_versions:
                raise unittest.SkipTest(
                    "Skip verify test for unsupported opset_version"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
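

# A minimal usage sketch (hypothetical test method):
#
#     @skipIfUnsupportedOpsetVersion([7, 8])
#     def test_op_changed_in_opset_9(self):
#         ...  # skipped when self.opset_version is 7 or 8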


def skipShapeChecking(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.check_shape = False
        return func(self, *args, **kwargs)

    return wrapper


def skipDtypeChecking(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.check_dtype = False
        return func(self, *args, **kwargs)

    return wrapper
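

# Unlike the decorators above, these two relax verification instead of skipping
# the test. A minimal usage sketch (hypothetical test method; assumes the
# verification harness reads `self.check_shape` and `self.check_dtype`):
#
#     @skipDtypeChecking
#     def test_op_with_known_dtype_mismatch(self):
#         ...  # output values are still compared, dtypes are not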


# Flattens nested lists/tuples/dicts into a flat tuple of the tensors they
# contain.
def flatten(x):
    return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
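

# An illustrative sketch of `flatten` (names are hypothetical):
#
#     a, b, c = torch.zeros(1), torch.zeros(2), torch.zeros(3)
#     flatten((a, [b, (c,)]))  # -> (a, b, c)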


def set_rng_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


class ExportTestCase(common_utils.TestCase):
    """Test case for ONNX export.

    Any test case that tests functionality under torch.onnx should inherit from
    this class.
    """

    def setUp(self):
        super().setUp()
        # TODO(#88264): Flaky test failures after changing seed.
        set_rng_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        diagnostics.engine.clear()
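

# A minimal usage sketch (hypothetical subclass): inheriting from
# ExportTestCase gives every test a fixed RNG seed and a cleared diagnostics
# engine before it runs.
#
#     class TestQuantizedExport(ExportTestCase):
#         def test_quantized_add(self):
#             ...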