# Owner(s): ["module: onnx"]
"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.
Usage:
pytest test/onnx/test_op_consistancy.py
To run tests on a specific operator (e.g. torch.ceil):
pytest test/onnx/test_op_consistancy.py -k ceil
Read more on Running and writing tests:
https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
Note:
When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
ALLOWLIST_OP lists. See the "Modify this section" comment below.
"""
import copy
import dataclasses
import unittest
from typing import (
AbstractSet,
Callable,
Collection,
Iterable,
Optional,
Sequence,
Tuple,
Union,
)
import onnx_test_common
import torch
from torch.onnx import _constants
from torch.testing._internal import (
common_device_type,
common_methods_invocations,
common_utils,
)
from torch.testing._internal.opinfo import core as opinfo_core
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
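# For example, if _constants.ONNX_MAX_OPSET were 16 (the actual value depends on the
# installed torch), TESTED_OPSETS would cover opsets 9 through 16 inclusive.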
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
QINT_TYPES = (
torch.qint8,
torch.quint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
COMPLEX_TYPES = (
torch.complex32,
torch.complex64,
torch.complex128,
)
SUPPORTED_DTYPES = (
# Boolean
torch.bool,
# Integers
*INT_TYPES,
# Floating types
*FLOAT_TYPES,
)
@dataclasses.dataclass
class DecorateMeta:
"""Information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
Attributes:
op_name: The name of the operator.
variant_name: The name of the OpInfo variant.
decorator: The decorator to apply to the test case.
opsets: The opsets to apply the decorator to.
dtypes: The dtypes to apply the decorator to.
reason: The reason for skipping or expecting the failure.
"""
op_name: str
variant_name: str
decorator: Callable
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
dtypes: Optional[Collection[torch.dtype]]
reason: str
def contains_opset(self, opset: int) -> bool:
if self.opsets is None:
return True
return any(
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
for opset_spec in self.opsets
)
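# A minimal sketch of how contains_opset resolves mixed opset specs (illustrative
# only; the lambda stands in for helpers like opsets_after defined further below):
#
#   meta = DecorateMeta(
#       op_name="ceil", variant_name="", decorator=unittest.expectedFailure,
#       opsets=[9, lambda opset: opset >= 15], dtypes=None, reason="example",
#   )
#   meta.contains_opset(9)   # True: matches the int spec
#   meta.contains_opset(16)  # True: matches the callable spec
#   meta.contains_opset(12)  # False: matches neither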
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Expects a OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
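# An illustrative call (not an entry used by this suite): expect a hypothetical
# "floor" test to fail for bool inputs on opsets before 11. dont_care and fixme
# take the same arguments but produce skips instead of expected failures.
#
#   xfail(
#       "floor",
#       dtypes=BOOL_TYPES,
#       opsets=[opsets_before(11)],
#       reason=reason_onnx_does_not_support("Floor", ["bool"]),
#   )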
def dont_care(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Skips a test case in OpInfo that we don't care about.
Likely because ONNX does not support the use case or it is by design.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Don't care: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def fixme(
op_name: str,
variant_name: str = "",
*,
reason: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
):
"""Skips a test case in OpInfo. It should be eventually fixed.
Args:
op_name: The name of the operator.
variant_name: The name of the variant.
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
dtypes: The dtypes to expect the failure.
reason: The reason for the failure.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"To fix: {reason}"),
opsets=opsets,
dtypes=dtypes,
reason=reason,
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
opset: int,
skip_or_xfails: Iterable[DecorateMeta],
):
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
Args:
all_opinfos: All OpInfos.
test_class_name: The name of the test class.
base_test_name: The name of the test method.
opset: The opset to decorate for.
skip_or_xfails: DecorateMeta's.
"""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
if not decorate_meta.contains_opset(opset):
# Skip does not apply to this opset
continue
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# The skips/xfails take effect through the OpInfo entries modified above,
# so the returned decorator is a no-op that leaves fn unchanged.
def wrapped(fn):
return fn
return wrapped
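# Note: add_decorate_info mutates the matching OpInfo entries in all_opinfos in
# place; the returned pass-through decorator does not need to be applied. A sketch
# of a direct call (mirroring parameterize_opsets below; the opset is arbitrary):
#
#   add_decorate_info(
#       OPS_DB,
#       "TestOnnxModelOutputConsistency",
#       "test_output_match_opset_15",
#       opset=15,
#       skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
#   )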
def opsets_before(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is before the specified."""
def compare(other_opset: int):
return other_opset < opset
return compare
def opsets_after(opset: int) -> Callable[[int], bool]:
"""Returns a comparison function that decides if the given opset is after the specified."""
def compare(other_opset: int):
return other_opset > opset
return compare
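# Illustrative behavior of the two helpers: opsets_before(13)(9) is True and
# opsets_before(13)(13) is False, so a DecorateMeta with opsets=[opsets_before(13)]
# applies to opsets 9 through 12; opsets_after(13) matches 14 and above.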
def reason_onnx_runtime_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
def reason_onnx_does_not_support(
operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
"""Formats the reason: ONNX doesn't support the given dtypes."""
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
def reason_jit_tracer_error(info: str) -> str:
"""Formats the reason: JIT tracer errors."""
return f"JIT tracer error on {info}"
def reason_flaky() -> str:
"""Formats the reason: test is flaky."""
return "flaky test"
# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to ALLOWLIST_OP then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.
# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch
ALLOWLIST_OP: AbstractSet[str] = frozenset(
[
"ceil",
"sqrt",
"t",
]
)
# fmt: off
# Turn off black formatting to keep the list compact
# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs dont_care vs xfail?
# A: Use xfail if a test fails consistently and we want to fix it eventually.
# Use fixme if we want to fix the test eventually but it does not fail consistently,
# e.g. the test is flaky or only some sample inputs fail.
# Use dont_care if we don't care about the test passing, e.g. ONNX does not support the usage.
EXPECTED_SKIPS_OR_FAILS: Tuple[DecorateMeta, ...] = (
dont_care(
"ceil", dtypes=BOOL_TYPES + INT_TYPES,
reason=reason_onnx_does_not_support("Ceil")
),
fixme("ceil", dtypes=[torch.float64], reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"])),
dont_care("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")),
)
# fmt: on
# END OF SECTION TO MODIFY #####################################################
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
class SingleOpModel(torch.nn.Module):
"""Test model to wrap around a single op for export."""
def __init__(self, op, kwargs):
super().__init__()
self.operator = op
self.kwargs = kwargs
def forward(self, *args):
return self.operator(*args, **self.kwargs)
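# A minimal sketch of the wrapper (illustrative only; run_test below exercises it
# through ONNX export rather than by direct calls):
#
#   model = SingleOpModel(torch.ceil, {})
#   model(torch.tensor([1.2, 2.7]))  # tensor([2., 3.]), same as torch.ceil(...)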
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
"""Test output consistency between exported ONNX models and PyTorch eager mode.
This is a parameterized test suite.
"""
@classmethod
def create_test_base(cls, opset: int):
"""Returns the base test method for the given opset."""
def _output_match_base(self, device: str, dtype: torch.dtype, op):
"""Base test method for testing each opset, used by instantiate_device_type_tests."""
# device is provided by instantiate_device_type_tests, but we only want to run on CPU.
assert device == "cpu"
samples = op.sample_inputs(
device,
dtype,
requires_grad=False,
)
for (i, cpu_sample) in enumerate(samples):
# Provide the repr to subtest because tensors are not serializable in parallel test runs
with self.subTest(
opset=opset,
sample_num=i,
input=repr(cpu_sample.input),
args=repr(cpu_sample.args),
kwargs=repr(cpu_sample.kwargs),
):
model = SingleOpModel(op, cpu_sample.kwargs)
model.eval()
# Run the test
inputs = (cpu_sample.input, *cpu_sample.args)
self.run_test(model, inputs)
test_name = f"test_output_match_opset_{opset}"
_output_match_base.__name__ = test_name
return _output_match_base
@classmethod
def parameterize_opsets(cls, opsets: Sequence[int]):
"""Parametrizes the TestOnnxModelOutputConsistency class with the given opsets."""
for opset in opsets:
# Generate a test method for each opset
base_method = cls.create_test_base(opset)
# The test method keeps the name assigned in create_test_base so that DecorateInfo can find it
test_name = base_method.__name__
# Update the ops to skip in the OpInfo database
add_decorate_info(
OPS_DB,
cls.__name__,
test_name,
opset=opset,
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
)
# Create parameterized tests for each op
filtered_ops = [op for op in OPS_DB if op.name in ALLOWLIST_OP]
decorated = common_device_type.ops(
filtered_ops,
allowed_dtypes=SUPPORTED_DTYPES,
)(base_method)
setattr(cls, test_name, decorated)
TestOnnxModelOutputConsistency.parameterize_opsets(TESTED_OPSETS)
common_device_type.instantiate_device_type_tests(
TestOnnxModelOutputConsistency, globals(), only_for="cpu"
)
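# The instantiation above expands each parameterized method into per-op, per-dtype
# CPU variants (names roughly of the form
# test_output_match_opset_<opset>_<op>_cpu_<dtype>), which is what lets
# `pytest ... -k ceil` select only the ceil tests.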
if __name__ == "__main__":
common_utils.run_tests()