| # Owner(s): ["module: onnx"] |
| |
| """Test consistency between the output values of torch.onnx FX exported operators |
| and torch operators given the same inputs. |
| |
| Usage: |
| |
| 1. Test all operators: |
| |
| pytest test/onnx/test_fx_op_consistency.py |
| |
| 2. To run tests on a specific operator (e.g. torch.ceil): |
| |
| pytest test/onnx/test_fx_op_consistency.py -k ceil |
| pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention |
| |
| 3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. |
| |
| CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int |
| |
| NOTE: Read more on Running and writing tests: |
| https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests |
| |
| Note: |
| |
| 1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored. |
| |
| 2. Install pytest-xdist to run tests in parallel if running all tests is the goal. |
| |
| 3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and |
| TESTED_OPS lists. See "Modify this section" |
| |
| """ |
| |
| from __future__ import annotations |
| |
| import copy |
| import itertools |
| import os |
| from typing import ( |
| Any, |
| Callable, |
| Collection, |
| List, |
| Mapping, |
| Optional, |
| Tuple, |
| Type, |
| Union, |
| ) |
| |
| import error_reproduction |
| |
| import onnx_test_common |
| |
| import parameterized |
| import pytest |
| import pytorch_test_common |
| |
| import torch |
| from onnx_test_common import skip, skip_slow, xfail |
| from torch.onnx._internal.diagnostics import _rules |
| from torch.testing._internal import ( |
| common_device_type, |
| common_methods_invocations, |
| common_utils, |
| ) |
| from torch.testing._internal.opinfo import core as opinfo_core |
| |
| |
| # NOTE: For ATen signature modifications that will break ONNX export, |
| # use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip |
| # to make the signal apparent for maintainers. |
def xfail_torchlib_forward_compatibility(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    github_issue: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], bool]] = None,
    enabled_if: bool = True,
):
    """Mark an op as an expected failure caused by a torchlib forward-compatibility break.

    This is a thin wrapper around ``xfail``: the only thing it adds is the
    tracking GitHub issue, appended to ``reason`` as
    ``"<reason>. GitHub Issue: <github_issue>"``. Every other argument is
    forwarded to ``xfail`` unchanged.

    Prefer this (xfail) over ``skip_torchlib_forward_compatibility`` when
    possible; only skip when the test does not fail consistently.

    Args:
        op_name: Name of the op whose test is expected to fail.
        variant_name: OpInfo variant name; empty string for the base op.
        reason: Human-readable explanation of the failure.
        github_issue: Issue link/identifier appended to ``reason``.
        opsets: Forwarded to ``xfail``.
        dtypes: Forwarded to ``xfail``.
        matcher: Forwarded to ``xfail``.
        enabled_if: Forwarded to ``xfail``.

    Returns:
        Whatever ``xfail`` returns for these arguments (presumably an
        ``onnx_test_common.DecorateMeta`` entry — confirm in onnx_test_common).
    """
    # Embed the issue reference so maintainers can trace the breakage directly
    # from the test report.
    annotated_reason = f"{reason}. GitHub Issue: {github_issue}"
    return xfail(
        op_name,
        variant_name=variant_name,
        reason=annotated_reason,
        dtypes=dtypes,
        opsets=opsets,
        matcher=matcher,
        enabled_if=enabled_if,
    )
| |
| |
def skip_torchlib_forward_compatibility(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    github_issue: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], Any]] = None,
    enabled_if: bool = True,
):
    """Skip an op whose test is broken by a torchlib forward-compatibility change.

    This is a thin wrapper around ``skip``. It forwards all arguments unchanged,
    except ``reason``, which gets the tracking issue appended as
    ``"<reason>. GitHub Issue: <github_issue>"``.

    Prefer ``xfail_torchlib_forward_compatibility`` over this (skip) when
    possible; only skip when the test does not fail consistently.

    Args:
        op_name: Name of the op whose test should be skipped.
        variant_name: OpInfo variant name; empty string for the base op.
        reason: Human-readable explanation of why the test is skipped.
        github_issue: Issue link/identifier appended to ``reason``.
        opsets: Forwarded to ``skip``.
        dtypes: Forwarded to ``skip``.
        matcher: Forwarded to ``skip``.
        enabled_if: Forwarded to ``skip``.

    Returns:
        Whatever ``skip`` returns for these arguments (presumably an
        ``onnx_test_common.DecorateMeta`` entry — confirm in onnx_test_common).
    """
    forwarded_kwargs = dict(
        variant_name=variant_name,
        # The issue reference makes the skip easy to trace for maintainers.
        reason=f"{reason}. GitHub Issue: {github_issue}",
        opsets=opsets,
        dtypes=dtypes,
        matcher=matcher,
        enabled_if=enabled_if,
    )
    return skip(op_name, **forwarded_kwargs)
| |
| |
| # fmt: off |
| # Turn off black formatting to keep the list compact |
| |
| # Expected failures for onnx export. |
| # The list should be sorted alphabetically by op name. |
| # Q: When should I use fixme vs skip vs xfail? |
| # A: Prefer xfail over skip when possible. |
| # 2a. If a test now fails because of an unexpected pass (xpass), i.e. a |
| # previous error has been fixed, remove the corresponding xfail. |
| # 2b. If a test is not failing consistently, use skip. |
| EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = ( |
| xfail( |
| "__getitem__", |
| reason="io_adaper doesn't support __getitem__ input slice(0, 3, None)", |
| ), |
| xfail( |
| "__radd__", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"), |
| ), |
| xfail( |
| "__rmatmul__", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "__rpow__", |
| dtypes=onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"), |
| ), |
| skip( |
| "_native_batch_norm_legit", |
| dtypes=(torch.float16,), |
| reason="fixme: Assertion error: result mismatch and type error", |
| ), |
| xfail( |
| "_softmax_backward_data", |
| reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)") |
| ), |
| xfail( |
| "add", dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Add") |
| ), |
| xfail( |
| "add", |
| dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_script_does_not_support( |
| "Add", "int8, int16, uint8 have type issue." |
| ), |
| ), |
| xfail( |
| "addbmm", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64") |
| ), |
| xfail( |
| "addmm", dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Addmm") |
| ), |
| xfail( |
| "addmm", |
| variant_name="decomposed", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Addmm") |
| ), |
| skip( |
| "addmm", dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") |
| ), |
| skip( |
| "addmm", |
| variant_name="decomposed", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") |
| ), |
| xfail( |
| "addr", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support( |
| "Addr", "bool" |
| ), |
| ), |
| xfail( |
| "addr", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") |
| ), |
| xfail( |
| "all", |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") |
| ), |
| xfail( |
| "allclose", |
| reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") |
| ), |
| xfail( |
| "amax", |
| dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), |
| reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), |
| ), |
| xfail( |
| "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), |
| reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16") |
| ), |
| xfail( |
| "aminmax", |
| dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), |
| reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), |
| ), |
| xfail( |
| "any", |
| reason=onnx_test_common.reason_onnx_does_not_support( |
| "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") |
| ), |
| xfail( |
| "arange", |
| dtypes=(torch.uint8,), |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"), |
| ), |
| xfail( |
| "arange", |
| dtypes=(torch.int16, torch.int32), |
| reason="AssertionError: The values for attribute 'shape' do not match", |
| ), |
| xfail( |
| "argmax", |
| dtypes=( |
| torch.int16, |
| torch.int64, |
| ), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "ArgMax", "int16, int64" |
| ), |
| ), |
| xfail( |
| "argmin", |
| dtypes=( |
| torch.uint8, |
| torch.int8, |
| torch.int16, |
| torch.int64, |
| ), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "ArgMin", "uint8, int8, int16, int64" |
| ), |
| ), |
| xfail( |
| "argwhere", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| skip( |
| "as_strided", |
| variant_name="partial_views", |
| reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults", |
| ), |
| xfail( |
| "atan2", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "baddbmm", |
| dtypes=( |
| torch.uint8, |
| torch.int8, |
| torch.int16, |
| ), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Matmul", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "baddbmm", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64") |
| ), |
| xfail( |
| "bernoulli", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "bfloat16", |
| reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.", |
| ), |
| xfail( |
| "bincount", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), |
| ), |
| xfail( |
| "bmm", |
| dtypes=( |
| torch.uint8, |
| torch.int8, |
| torch.int16, |
| ), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Matmul", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "broadcast_shapes", |
| reason=onnx_test_common.reason_dynamo_does_not_support("output is int"), |
| ), |
| xfail( |
| "cauchy", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| skip( |
| "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int") |
| ), |
| xfail( |
| "chalf", |
| reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." |
| ), |
| xfail( |
| "chunk", dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool") |
| ), |
| xfail( |
| "chunk", |
| dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Chunk", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "clamp", |
| dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Max", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "clamp_max", dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool") |
| ), |
| xfail( |
| "clamp_max", |
| dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Max", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "clamp_min", |
| dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Max", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "clamp_min", dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool") |
| ), |
| xfail( |
| "constant_pad_nd", |
| dtypes=(torch.int16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Constant_pad_nd", "int16" |
| ), |
| ), |
| xfail( |
| "constant_pad_nd", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support( |
| "Constant_pad_nd", "complex64" |
| ), |
| ), |
| xfail( |
| "corrcoef", |
| reason=onnx_test_common.reason_dynamo_does_not_support( |
| "aten.equal.default" |
| ), |
| ), |
| xfail( |
| "cov", |
| reason=onnx_test_common.reason_dynamo_does_not_support( |
| "aten.equal.default" |
| ), |
| ), |
| xfail( |
| "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16") |
| ), |
| xfail( |
| "combinations", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), |
| ), |
| xfail( |
| "cross", |
| reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"), |
| ), |
| xfail( |
| "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16") |
| ), |
| skip( |
| "dot", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"), |
| ), |
| xfail( |
| "empty", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." |
| ), |
| xfail( |
| "empty_strided", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "eq", |
| dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"), |
| ), |
| xfail( |
| "equal", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") |
| ), |
| xfail( |
| "exponential", |
| reason=onnx_test_common.reason_dynamo_does_not_support("exponential"), |
| ), |
| xfail( |
| "fft.fft", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.fft2", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.fftn", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.ifft", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.ifft2", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.ifftn", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.irfft", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.irfft2", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "fft.irfftn", |
| reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), |
| ), |
| xfail( |
| "fft.rfft", |
| reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), |
| ), |
| xfail( |
| "fft.rfftn", |
| reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), |
| ), |
| xfail( |
| "fft.rfft2", |
| reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), |
| ), |
| xfail( |
| "floor", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), |
| ), |
| xfail( |
| "floor_divide", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), |
| ), |
| xfail( |
| "full", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64") |
| ), |
| xfail( |
| "full_like", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") |
| ), |
| xfail( |
| "geometric", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "heaviside", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), |
| ), |
| xfail( |
| "index_fill", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") |
| ), |
| xfail( |
| "index_put", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), |
| ), |
| xfail( |
| "index_put", |
| dtypes=(torch.uint8, torch.int8, torch.int16,), |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), |
| ), |
| xfail( |
| "isnan", |
| dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"), |
| ), |
| xfail( |
| "istft", |
| reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), |
| ), |
| xfail( |
| "item", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "lerp", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64") |
| ), |
| xfail( |
| "linalg.lstsq", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), |
| ), |
| xfail( |
| "linalg.lstsq", |
| variant_name="grad_oriented", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), |
| ), |
| xfail( |
| "linalg.norm", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "linalg.norm", |
| variant_name="subgradients_at_zero", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "linalg.vecdot", |
| reason="fixme: Assertion error: result shape mismatch", |
| ), |
| xfail( |
| "linspace", |
| dtypes=(torch.int64, torch.int32,), |
| reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", |
| ), |
| xfail( |
| "linspace", |
| variant_name="tensor_overload", |
| dtypes=(torch.int64, torch.int32,), |
| reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", |
| ), |
| xfail( |
| "linspace", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") |
| ), |
| xfail( |
| "linspace", |
| variant_name="tensor_overload", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") |
| ), |
| xfail( |
| "log_normal", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "log_softmax", |
| dtypes=(torch.float16,), |
| reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", |
| ), |
| xfail( |
| "log_softmax", |
| variant_name="with_dtype", |
| dtypes=(torch.float16,), |
| reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", |
| ), |
| xfail( |
| "logcumsumexp", |
| reason=onnx_test_common.reason_onnx_does_not_support( |
| "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") |
| ), |
| xfail( |
| "logical_and", |
| dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"), |
| ), |
| xfail( |
| "logical_not", |
| dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"), |
| ), |
| xfail( |
| "logical_or", |
| dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"), |
| ), |
| xfail( |
| "logical_xor", |
| dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), |
| ), |
| xfail( |
| "logsumexp", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceLogSumExp", "bool, int"), |
| ), |
| xfail( |
| "masked.logsumexp", |
| reason="fixme: https://github.com/onnx/onnx/issues/4986", |
| ), |
| xfail( |
| "masked.amax", |
| reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", |
| ), |
| xfail( |
| "masked.amin", |
| reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", |
| ), |
| xfail( |
| "masked.argmin", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "masked.argmax", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "masked_fill", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), |
| ), |
| xfail( |
| "masked.sum", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), |
| ), |
| xfail( |
| "masked.log_softmax", |
| dtypes=(torch.float16,), |
| reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", |
| ), |
| xfail( |
| "masked.mean", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"), |
| ), |
| xfail( |
| "masked.norm", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "masked.prod", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), |
| ), |
| xfail( |
| "masked_select", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"), |
| ), |
| xfail( |
| "max", |
| variant_name="reduction_no_dim", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), |
| ), |
| xfail( |
| "max", |
| variant_name="reduction_with_dim", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), |
| ), |
| xfail( |
| "max", |
| variant_name="reduction_with_dim", |
| reason="https://github.com/onnx/onnx/issues/4986", |
| ), |
| xfail( |
| "mean", |
| reason="(ReduceMean) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", |
| ), |
| xfail( |
| "min", |
| variant_name="reduction_no_dim", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), |
| ), |
| xfail( |
| "min", |
| variant_name="reduction_with_dim", |
| dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), |
| ), |
| skip( |
| "mm", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"), |
| ), |
| xfail( |
| "multinomial", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nanquantile", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") |
| ), |
| xfail( |
| "nansum", |
| dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"), |
| ), |
| xfail( |
| "narrow", |
| reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), |
| ), |
| xfail( |
| "native_batch_norm", |
| dtypes=(torch.float16,), |
| reason="fixme: https://github.com/microsoft/onnxscript/issues/1269", |
| ), |
| xfail( |
| "native_layer_norm", |
| dtypes=(torch.float16,), |
| reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", |
| ), |
| xfail( |
| "new_full", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64") |
| ), |
| xfail( |
| "nn.functional.adaptive_avg_pool2d", |
| reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \ |
| maximum recursion depth exceeded while calling a Python object"), |
| ), |
| xfail( |
| "nn.functional.adaptive_avg_pool3d", |
| reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"), |
| ), |
| xfail( |
| "nn.functional.alpha_dropout", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.avg_pool1d", |
| dtypes=onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), |
| ), |
| xfail( |
| "nn.functional.avg_pool2d", |
| dtypes=onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), |
| ), |
| xfail( |
| "nn.functional.avg_pool3d", |
| dtypes=onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), |
| ), |
| xfail( |
| "nn.functional.batch_norm", |
| dtypes=(torch.float16,), |
| reason="fixme: https://github.com/microsoft/onnxscript/issues/1270", |
| ), |
| xfail( |
| "nn.functional.conv_transpose1d", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), |
| ), |
| xfail( |
| "nn.functional.conv_transpose2d", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), |
| ), |
| xfail( |
| "nn.functional.conv_transpose3d", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), |
| ), |
| skip( |
| "nn.functional.conv_transpose1d", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| skip( |
| "nn.functional.conv_transpose2d", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| skip( |
| "nn.functional.conv_transpose3d", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nn.functional.conv1d", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), |
| ), |
| xfail( |
| "nn.functional.conv2d", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), |
| ), |
| xfail( |
| "nn.functional.conv2d", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nn.functional.conv3d", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), |
| ), |
| xfail( |
| "nn.functional.conv3d", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nn.functional.ctc_loss", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), |
| ), |
| xfail( |
| "nn.functional.dropout", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.dropout2d", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.dropout3d", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.feature_alpha_dropout", |
| variant_name="with_train", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.feature_alpha_dropout", |
| variant_name="without_train", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.fractional_max_pool2d", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.fractional_max_pool3d", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.gaussian_nll_loss", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"), |
| ), |
| xfail( |
| "nn.functional.grid_sample", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nn.functional.group_norm", |
| dtypes=(torch.float16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"), |
| ), |
| xfail( |
| "nn.functional.local_response_norm", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"), |
| ), |
| xfail( |
| "nn.functional.linear", |
| dtypes=onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"), |
| ), |
| xfail( |
| "nn.functional.max_pool2d", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"), |
| ), |
| xfail( |
| "nn.functional.max_pool3d", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"), |
| ), |
| xfail( |
| "nn.functional.multi_head_attention_forward", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.one_hot", |
| reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), |
| ), |
| xfail( |
| "nn.functional.pad", |
| variant_name="replicate", |
| reason="fixme: ORT error: padding size", |
| ), |
| xfail( |
| "nn.functional.pad", |
| variant_name="replicate_negative", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nn.functional.pad", |
| variant_name="reflect", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nn.functional.rrelu", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "nn.functional.rrelu", |
| dtypes=(torch.int64,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"), |
| ), |
| skip( |
| "nn.functional.scaled_dot_product_attention", |
| matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, |
| reason="dropout is random so the results do not match", |
| ), |
| xfail( |
| "nn.functional.scaled_dot_product_attention", |
| dtypes=(torch.float16,), |
| reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", |
| ), |
| xfail( |
| "nn.functional.selu", |
| reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table", |
| ), |
| xfail( |
| "nn.functional.soft_margin_loss", |
| dtypes=(torch.float16,), |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nn.functional.tanhshrink", |
| dtypes=(torch.float16,), |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "nonzero", |
| dtypes=(torch.int8, torch.int16), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"), |
| ), |
| xfail( |
| "normal", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "normal", |
| variant_name="in_place", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "normal", |
| variant_name="number_mean", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "ones", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." |
| ), |
| xfail( |
| "pca_lowrank", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "quantile", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") |
| ), |
| xfail( |
| "rand_like", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "randint", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "randint_like", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "randn", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "randn_like", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "resize_", |
| reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") |
| ), |
| xfail( |
| "resize_as_", |
| reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") |
| ), |
| xfail( |
| "round", |
| dtypes=onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"), |
| ), |
| xfail( |
| "rsub", |
| dtypes=(torch.uint8, torch.int8, torch.int16), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Mul", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "scatter_add", |
| dtypes=(torch.float16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), |
| ), |
| xfail( |
| "scatter_reduce", |
| variant_name="sum", |
| dtypes=(torch.float16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), |
| ), |
| xfail( |
| "scatter_reduce", |
| variant_name="prod", |
| dtypes=(torch.float16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"), |
| ), |
| xfail( |
| "scatter_reduce", |
| variant_name="amin", |
| dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"), |
| ), |
| xfail( |
| "scatter_reduce", |
| variant_name="amax", |
| dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"), |
| ), |
| xfail( |
| "scatter_reduce", |
| variant_name="mean", |
| reason="ONNX doesn't support reduce='mean' option", |
| ), |
| xfail( |
| "sign", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), |
| ), |
| xfail( |
| "signal.windows.kaiser", |
| reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"), |
| ), |
| xfail( |
| "softmax", |
| dtypes=(torch.float16,), |
| reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438" |
| ), |
| xfail( |
| "sparse.mm", |
| variant_name="reduce", |
| reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), |
| ), |
| xfail( |
| "sparse.sampled_addmm", |
| reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), |
| ), |
| xfail( |
| "special.erfcx", |
| dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"), |
| ), |
| xfail( |
| "special.erfcx", |
| dtypes=onnx_test_common.FLOAT_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), |
| ), |
| xfail( |
| "special.ndtr", |
| dtypes=(torch.float16,), |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "split", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), |
| ), |
| xfail( |
| "split", |
| variant_name="list_args", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), |
| ), |
| xfail( |
| "split_with_sizes", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), |
| ), |
| xfail( |
| "square", |
| dtypes=(torch.int8, torch.uint8, torch.int16), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), |
| ), |
| xfail( |
| "squeeze", |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "squeeze", |
| variant_name="multiple", |
| reason="fixme: https://github.com/microsoft/onnxscript/issues/1264", |
| ), |
| xfail( |
| "svd_lowrank", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "std_mean", |
| reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." |
| ), |
| xfail( |
| "std_mean", |
| variant_name="unbiased", |
| reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." |
| ), |
| xfail( |
| "stft", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), |
| ), |
| xfail( |
| "sub", |
| dtypes=(torch.uint8, torch.int8, torch.int16), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support( |
| "Mul", "uint8, int8, int16" |
| ), |
| ), |
| xfail( |
| "take", |
| reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), |
| ), |
| xfail( |
| "tensor_split", |
| reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), |
| ), |
| xfail( |
| "to", |
| dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.bool, torch.complex64), |
| # model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, |
| model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, |
| reason="This op requires torch.dtype as input, which is not supported currently.", |
| ), |
| xfail( |
| "topk", |
| dtypes=(torch.int64, torch.int32), |
| reason="fixme: Assertion error: result mismatch", |
| ), |
| xfail( |
| "tril", |
| dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), |
| ), |
| xfail( |
| "triu", |
| dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), |
| ), |
| xfail( |
| "trunc", |
| dtypes=onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), |
| ), |
| xfail( |
| "unbind", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), |
| ), |
| xfail( |
| "unflatten", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_does_not_support("Unflatten") |
| ), |
| xfail( |
| "uniform", |
| reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), |
| ), |
| xfail( |
| "unique", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), |
| ), |
| xfail( |
| "unique_consecutive", |
| reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), |
| ), |
| xfail( |
| "unravel_index", |
| dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, |
| reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), |
| ), |
| xfail( |
| "unsafe_split", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), |
| ), |
| xfail( |
| "unsafe_chunk", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), |
| ), |
| xfail( |
| "where", |
| dtypes=onnx_test_common.BOOL_TYPES, |
| reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), |
| ), |
| xfail( |
| "zeros", |
| dtypes=onnx_test_common.COMPLEX_TYPES, |
| reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." |
| ), |
| # SLOW TESTS (All are xfails if we run them) |
| # TODO: https://github.com/pytorch/pytorch/issues/117118 |
| skip_slow( |
| "cdist", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "histogram", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "histogramdd", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "linalg.lu_solve", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "linalg.solve_triangular", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "linalg.svd", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "logspace", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "logspace", |
| variant_name="tensor_overload", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "max_pool2d_with_indices_backward", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.interpolate", |
| variant_name="bicubic", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.max_unpool1d", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.max_unpool2d", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.max_unpool3d", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.max_pool1d", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.max_pool2d", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.max_pool3d", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "nn.functional.unfold", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "ormqr", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "searchsorted", |
| reason="fixme: Test sets are too many.", |
| ), |
| skip_slow( |
| "svd", |
| reason="fixme: Test sets are too many.", |
| ), |
| ) |
| # fmt: on |
| |
# Per-sample skips/xfails, consulted for every OpInfo sample by
# ``_should_skip_xfail_test_sample`` (as opposed to EXPECTED_SKIPS_OR_FAILS,
# which is attached to whole tests through ``add_decorate_info``):
#   * ``op_name`` / ``variant_name`` select the OpInfo entry;
#   * ``model_type`` (if given) restricts the entry to one TorchModelType;
#   * ``matcher(sample)`` (if given) restricts the entry to matching samples.
# Each entry must set at least one of ``matcher`` / ``model_type``. Entries are
# scanned in order and the first one that applies decides the outcome.
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "_native_batch_norm_legit",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="https://github.com/pytorch/pytorch/issues/115106",
    ),
    xfail(
        "addmm",  # xfail can't only use dtypes to catch all cases
        matcher=lambda sample: sample.input.dtype
        in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Gemm", "uint8, int8, int16, int32, int64"
        ),
    ),
    xfail(
        "addmm",
        matcher=lambda sample: sample.args[0].numel() == 0,
        reason="ONNX Runtime does not support empty tensors multiplication",
    ),
    xfail(
        "addmm",
        variant_name="decomposed",
        matcher=lambda sample: sample.args[0].numel() == 0,
        reason="ONNX Runtime does not support empty tensors multiplication",
    ),
    xfail(
        "amax",
        matcher=lambda sample: len(sample.input.shape) == 0
        and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()),
        reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    xfail(
        "amin",
        matcher=lambda sample: len(sample.input.shape) == 0
        and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()),
        reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    xfail(
        "aminmax",
        matcher=lambda sample: len(sample.input.shape) == 0
        and sample.kwargs.get("dim") is not None,
        reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    skip(
        "cat",
        matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
        reason="core dump - cat does not support zero-dim tensors yet",
    ),
    xfail(
        "index_add",
        matcher=lambda sample: len(sample.input.shape) < 2,
        reason="fixme: https://github.com/microsoft/onnxscript/issues/1212",
    ),
    xfail(
        "index_add",
        matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1,
        reason="fixme: aten::index_put indices contains None when dim is -1",
    ),
    xfail(
        "index_copy",
        matcher=lambda sample: len(sample.input.shape) < 2,
        reason="fixme: https://github.com/microsoft/onnxscript/issues/1212",
    ),
    xfail(
        "index_copy",
        matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1,
        reason="fixme: aten::index_put indices contains None when dim is -1",
    ),
    xfail(
        "index_put",
        matcher=lambda sample: (sample.args[0][0].dtype == torch.bool)
        and (sample.kwargs.get("accumulate") is False),
        reason=onnx_test_common.reason_dynamo_does_not_support(
            "https://github.com/pytorch/pytorch/issues/101150"
        ),
    ),
    skip(
        "linalg.multi_dot",
        # Total element count over all input matrices; avoids shadowing the builtin ``input``.
        matcher=lambda sample: sum(torch.numel(t) for t in sample.input) == 0,
        reason="fixme: Undefined",
    ),
    skip(
        "log_softmax",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: LogSoftMax does not support empty tensor as input",
    ),
    skip(
        "log_softmax",
        variant_name="with_dtype",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: LogSoftMax does not support empty tensor as input",
    ),
    xfail(
        "logsumexp",
        matcher=lambda sample: isinstance(sample.input, torch.Tensor)
        and len(sample.input.shape) == 0,
        reason="fixme: IsScalar",
    ),
    skip(
        "masked.log_softmax",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: LogSoftMax does not support empty tensor as input",
    ),
    skip(
        "matmul",
        matcher=lambda sample: torch.numel(sample.input) == 0,
        reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
    ),
    xfail(
        "min",
        variant_name="reduction_with_dim",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: https://github.com/onnx/onnx/issues/4986",
    ),
    skip(
        "mm",
        matcher=lambda sample: torch.numel(sample.input) == 0,
        reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
    ),
    xfail(
        "native_batch_norm",
        matcher=lambda sample: sample.args[-3] is True
        and any(arg is not None for arg in sample.args[2:4]),
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="https://github.com/pytorch/pytorch/issues/115106",
    ),
    xfail(
        "nn.functional.avg_pool1d",
        matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True)
        and (
            sample.kwargs.get("count_include_pad") is True
            or sample.input.shape[2]
            % (
                sample.args[0][0]
                if isinstance(sample.args[0], tuple)
                else sample.args[0]
            )
            != 0
        ),
        reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
    ),
    xfail(
        "nn.functional.avg_pool2d",
        matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
        or (sample.kwargs.get("divisor_override") is not None),
        reason="ONNX doesn't support divisor_override argument",
    ),
    xfail(
        "nn.functional.avg_pool3d",
        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True,
        reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
    ),
    xfail(
        "nn.functional.avg_pool3d",
        matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
        or (sample.kwargs.get("divisor_override") is not None),
        reason="ONNX doesn't support divisor_override argument",
    ),
    xfail(
        "nn.functional.batch_norm",
        matcher=lambda sample: sample.kwargs.get("training") is True
        and any(arg is not None for arg in sample.args[2:4]),
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106",
    ),
    xfail(
        "nn.functional.conv2d",
        matcher=lambda sample: sample.kwargs.get("padding") == "valid",
        reason="fixme: https://github.com/pytorch/pytorch/issues/117054",
    ),
    xfail(
        "nn.functional.conv3d",
        matcher=lambda sample: sample.kwargs.get("padding") == "valid",
        reason="fixme: https://github.com/pytorch/pytorch/issues/117054",
    ),
    skip(
        "nn.functional.cross_entropy",
        matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int),
        reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type",
    ),
    xfail(
        "nn.functional.embedding",
        matcher=lambda sample: sample.kwargs.get("max_norm") is not None,
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="https://github.com/pytorch/pytorch/issues/115106",
    ),
    skip_torchlib_forward_compatibility(
        "nn.functional.embedding_bag",
        # Always matches: per the reason below, 'padding_idx=-1' is emitted even
        # when 'padding_idx' is not provided, so every sample is affected. (The
        # previous condition ended in ``or True`` and was therefore always true.)
        matcher=lambda sample: True,
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. "
            "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided"
        ),
        github_issue="https://github.com/microsoft/onnxscript/issues/1056",
    ),
    xfail(
        "nn.functional.group_norm",
        matcher=lambda sample: torch.numel(sample.input) == 0,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Reshape", "empty tensor"
        ),
    ),
    xfail(
        "nn.functional.instance_norm",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        matcher=lambda sample: sample.kwargs.get("running_mean") is not None
        or sample.input.dtype in (torch.float16,),
        reason="fixme: KeyError: 'self___kwargs__running_mean'",
    ),
    xfail(
        "nn.functional.max_pool3d",
        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
        and sample.kwargs.get("padding") == 1,
        reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed",
    ),
    xfail(
        "nonzero",
        matcher=lambda sample: len(sample.input.shape) == 0
        and sample.kwargs.get("as_tuple", False) is False,
        reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
    ),
    xfail(
        "nonzero",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "aten::_assert_async.msg",
            "https://github.com/pytorch/pytorch/issues/112443",
        ),
    ),
    xfail(
        "scatter_add",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch",
    ),
    skip(
        "scatter_reduce",
        variant_name="amax",
        # ONNX has not include_self parameter and default is include_self=True mode
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX does't support include_self=False option",
    ),
    skip(
        "scatter_reduce",
        variant_name="amin",
        # ONNX has not include_self parameter and default is include_self=True mode
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX does't support include_self=False option",
    ),
    skip(
        "scatter_reduce",
        variant_name="prod",
        # ONNX has not include_self parameter and default is include_self=True mode
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX does't support include_self=False option",
    ),
    skip(
        "scatter_reduce",
        variant_name="sum",
        # ONNX has not include_self parameter and default is include_self=True mode
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX does't support include_self=False option",
    ),
    skip(
        "softmax",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: LogSoftMax does not support empty tensor as input",
    ),
    xfail(
        "t",
        matcher=lambda sample: isinstance(sample.input, torch.Tensor)
        and len(sample.input.shape) < 2,
        reason="fixme: IsScalar",
    ),
    xfail(
        "unflatten",
        reason="Logic not implemented for size 0 inputs in op.Reshape",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    ),
    skip(
        "signal.windows.hamming",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="does not match node name",
    ),
    skip(
        "signal.windows.general_hamming",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="does not match node name",
    ),
    skip(
        "signal.windows.blackman",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="does not match node name",
    ),
    skip(
        "signal.windows.general_cosine",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="does not match node name",
    ),
    skip(
        "signal.windows.hann",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="does not match node name",
    ),
    skip(
        "signal.windows.nuttall",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="does not match node name",
    ),
)
| |
# Private deep copy of the shared OpInfo database. Decorate info is attached to
# these entries at the bottom of this module, presumably the reason a copy is
# used instead of ``op_db`` itself.
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
# Names of every OpInfo in the copied database.
ALL_OPS_IN_DB = frozenset({opinfo.name for opinfo in OPS_DB})
# Ops that have at least one per-sample skip/xfail; lets the per-sample lookup
# bail out early before scanning SKIP_XFAIL_SUBTESTS.
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(
    {subtest_meta.op_name for subtest_meta in SKIP_XFAIL_SUBTESTS}
)
| |
| |
| def _torch_size_flatten_spec(d: List[Any], spec: Any) -> List[Any]: |
| return [d[i] for i in range(spec.num_children)] |
| |
| |
# Register a flatten-spec function for ``torch.Size`` with torch.fx's pytree
# helpers (presumably needed so that values containing ``torch.Size`` can be
# flattened against a spec during export -- TODO confirm).
torch.fx._pytree.register_pytree_flatten_spec(torch.Size, _torch_size_flatten_spec)
| |
| |
class SingleOpModel(torch.nn.Module):
    """``nn.Module`` whose ``forward`` calls a single operator.

    Keyword arguments are fixed when the module is built; positional arguments
    are supplied on every call to ``forward``.
    """

    def __init__(self, op, kwargs):
        super().__init__()
        # The operator under test and the keyword arguments bound to every call.
        self.operator, self.kwargs = op, kwargs

    def forward(self, *args):
        bound_kwargs = self.kwargs
        return self.operator(*args, **bound_kwargs)
| |
| |
def _should_skip_xfail_test_sample(
    op_name: str,
    variant_test_name: str,
    sample,
    model_type: pytorch_test_common.TorchModelType,
) -> Tuple[Optional[str], Optional[str]]:
    """Check if the test sample should be skipped or xfailed.

    Entries of ``SKIP_XFAIL_SUBTESTS`` are scanned in order. An entry is a
    candidate when its ``op_name`` and ``variant_name`` equal ``op_name`` and
    ``variant_test_name``, and its ``model_type`` equals ``model_type`` or is
    ``None`` (``None`` means the entry covers every model type). For a candidate:

    * with a ``matcher``: it applies only if ``matcher(sample)`` is truthy;
      otherwise scanning continues;
    * without a ``matcher``: it applies to every sample of that model type.

    The first applicable entry decides the result.

    Args:
        op_name: The name of the op.
        variant_test_name: The OpInfo variant name of the op ("" for the
            default variant in the entries above, which omit ``variant_name``).
        sample: The test sample; passed to each candidate's ``matcher``.
        model_type: The model type of the test.

    Returns:
        A tuple of (test_behavior, reason). test_behavior is either "skip" or
        "xfail" and reason explains it. ``(None, None)`` if no entry applies.

    Raises:
        TypeError: If a candidate entry defines neither ``matcher`` nor
            ``model_type`` (such an entry would not be restricted at all).
    """
    # Cheap pre-filter: most ops have no per-sample entries.
    if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    # Linear search on SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
    for decorator_meta in SKIP_XFAIL_SUBTESTS:
        if (
            decorator_meta.op_name != op_name
            or decorator_meta.variant_name != variant_test_name
        ):
            continue
        # A model_type of None means the entry applies to every model type.
        if (
            decorator_meta.model_type is not None
            and model_type != decorator_meta.model_type
        ):
            continue
        if decorator_meta.matcher is None:
            if decorator_meta.model_type is None:
                raise TypeError(
                    "Either Matcher or model_type must be defined in sub xfail and skip."
                )
            # xfail/skip the whole test of the model type without matcher
            return decorator_meta.test_behavior, decorator_meta.reason
        if decorator_meta.matcher(sample):
            return decorator_meta.test_behavior, decorator_meta.reason
    return None, None
| |
| |
| def _compare_onnx_and_torch_exported_program( |
| torch_exported_program, |
| onnx_exported_program, |
| input_args, |
| input_kwargs=None, |
| test_name=None, |
| sample_num=None, |
| sample_kwargs=None, |
| rtol=1e-03, |
| atol=1e-07, |
| only_check_shape=False, |
| ): |
| # avoid mutable default argument |
| if input_kwargs is None: |
| input_kwargs = {} |
| |
| # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. |
| # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. |
| # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() |
| onnx_outputs = onnx_exported_program(*input_args, **input_kwargs) |
| if isinstance(torch_exported_program, torch.export.ExportedProgram): |
| torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) |
| else: |
| torch_outputs = torch_exported_program(*input_args, **input_kwargs) |
| torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( |
| torch_outputs |
| ) |
| if len(torch_outputs_onnx_format) != len(onnx_outputs): |
| raise AssertionError( |
| f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" |
| ) |
| |
| for j, (torch_output, onnx_output) in enumerate( |
| zip(torch_outputs_onnx_format, onnx_outputs) |
| ): |
| if only_check_shape: |
| assert torch_output.shape == onnx_output.shape |
| else: |
| try: |
| torch.testing.assert_close( |
| torch.tensor(onnx_output), |
| torch_output, |
| rtol=rtol, |
| atol=atol, |
| equal_nan=True, |
| ) |
| except AssertionError as e: |
| if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": |
| error_reproduction.create_mismatch_report( |
| test_name, |
| sample_num, |
| onnx_exported_program.model_proto, |
| input_args, |
| sample_kwargs, |
| torch.tensor(onnx_output), |
| torch_output, |
| e, |
| ) |
| if len(torch_outputs_onnx_format) > 1: |
| raise AssertionError(f"Output {j} mismatch") from e |
| raise |
| |
| |
def _select_tolerances(
    test_suite: onnx_test_common._TestONNXRuntime,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
) -> Tuple[Optional[float], Optional[float]]:
    """Pick the ``(rtol, atol)`` pair used to compare outputs of ``op`` at ``dtype``.

    Per-op entries in the suite's low-precision tables win. For float16, an
    ``(op name, variant)`` entry beats a plain op-name entry. Float32 ops
    without an entry get a fixed relaxed pair. Every other case returns
    ``(None, None)`` so ``torch.testing.assert_close`` uses its dtype defaults.
    """
    if dtype == torch.float32:
        if op.name in test_suite.fp32_low_precision_dict:
            tolerances = test_suite.fp32_low_precision_dict[op.name]
            return tolerances[0], tolerances[1]
        # Relax atol and rtol for float32 based on empirical results
        return 1e-5, 2e-5
    if dtype == torch.float16:
        variant_key = (op.name, op.variant_test_name)
        if variant_key in test_suite.fp16_low_precision_variant_dict:
            tolerances = test_suite.fp16_low_precision_variant_dict[variant_key]
            return tolerances[0], tolerances[1]
        if op.name in test_suite.fp16_low_precision_dict:
            tolerances = test_suite.fp16_low_precision_dict[op.name]
            return tolerances[0], tolerances[1]
    return None, None


def _run_test_output_match(
    test_suite: onnx_test_common._TestONNXRuntime,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
):
    """Export every OpInfo sample of ``op`` to ONNX and compare it with PyTorch.

    Each sample runs in its own subtest. Per-sample skips/xfails come from
    ``_should_skip_xfail_test_sample``. An ``AssertionError`` from
    ``torch.export.export``, or an ``OnnxExporterError`` caused by an
    unsupported FX node, becomes ``pytest.xfail``. Any other exporter error
    is re-raised.

    Args:
        test_suite: The running test case; supplies ``opset_version``,
            ``model_type``, ``subTest`` and the tolerance / shape-only tables.
        device: Device from ``instantiate_device_type_tests``; must be "cpu".
        dtype: Dtype of the generated sample inputs.
        op: The OpInfo under test.
    """
    # device is provided by instantiate_device_type_tests, but we only want to run in cpu.
    assert device == "cpu"
    # Tolerances and the shape-only flag depend on (dtype, op) only, not on the
    # sample, so they are computed once outside the sample loop.
    rtol, atol = _select_tolerances(test_suite, dtype, op)
    only_check_shape = op.name in test_suite.only_shape_check_list
    samples = op.sample_inputs(
        device,
        dtype,
        requires_grad=False,
    )
    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs

        with test_suite.subTest(
            opset=test_suite.opset_version,
            sample_num=i,
            inputs=repr(inputs),
            kwargs=repr(cpu_sample.kwargs),
        ):
            test_behavior, reason = _should_skip_xfail_test_sample(
                op.name, op.variant_test_name, cpu_sample, test_suite.model_type
            )
            with onnx_test_common.normal_xfail_skip_test_behaviors(
                test_behavior, reason
            ):
                model = SingleOpModel(op.op, cpu_sample.kwargs)
                model.eval()

                if (
                    test_suite.model_type
                    == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
                ):
                    try:
                        model = torch.export.export(model, inputs)
                    except AssertionError as e:
                        # NOTE: avoid fake_mode detection bug in torch.export.export
                        pytest.xfail(
                            onnx_test_common.reason_dynamo_does_not_support(str(e))
                        )

                try:
                    onnx_program = torch.onnx.dynamo_export(
                        model,
                        *inputs,
                    )
                except torch.onnx.OnnxExporterError as e:
                    # NOTE: If the model has unsupported nodes, we will skip the test
                    # with non-strict xfail. Otherwise, we will raise the error.
                    if hasattr(
                        e.__cause__, "diagnostic"
                    ) and e.__cause__.diagnostic.rule in (
                        _rules._POERules.no_symbolic_function_for_call_function,
                        _rules._POERules.unsupported_fx_node_analysis,
                    ):
                        pytest.xfail(
                            onnx_test_common.reason_onnx_script_does_not_support(str(e))
                        )
                    else:
                        raise
                _compare_onnx_and_torch_exported_program(
                    model,
                    onnx_program,
                    inputs,
                    test_name=test_suite.id(),
                    sample_num=i,
                    sample_kwargs=cpu_sample.kwargs,
                    rtol=rtol,
                    atol=atol,
                    only_check_shape=only_check_shape,
                )
| |
| |
def _parameterized_class_attrs_and_values():
    """Build the keyword arguments for ``parameterized.parameterized_class``.

    Returns:
        A dict with ``attrs`` naming the class attributes to override
        (``opset_version`` and ``model_type``) and ``input_values`` holding one
        ``(opset_version, model_type)`` tuple per generated class. The tuples
        pair every opset in ``onnx_test_common.FX_TESTED_OPSETS`` with the
        nn.Module and ExportedProgram model types, opset-major.
    """
    input_values = list(
        itertools.product(
            onnx_test_common.FX_TESTED_OPSETS,
            (
                pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
                pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
            ),
        )
    )
    return {
        "attrs": ["opset_version", "model_type"],
        "input_values": input_values,
    }
| |
| |
| def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): |
| """Combine class name with the parameterized arguments. |
| |
| This function is passed to `parameterized.parameterized_class` as the |
| `class_name_func` argument. |
| """ |
| suffixes = [] |
| for k, v in input_dicts.items(): |
| suffixes.append(f"{k}_{v}") |
| return f"{cls.__name__}_{'_'.join(suffixes)}" |
| |
| |
@parameterized.parameterized_class(
    **_parameterized_class_attrs_and_values(),
    class_name_func=_parameterize_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Compare outputs of ONNX-exported single-op models with PyTorch eager mode.

    ``parameterized_class`` generates one subclass per (opset_version,
    model_type) pair, overriding the two class attributes of the same names.
    """

    # Placeholders; overridden in every generated subclass.
    opset_version = -1
    model_type: pytorch_test_common.TorchModelType = (
        pytorch_test_common.TorchModelType.TORCH_NN_MODULE
    )
    op_level_debug: bool = False
    dynamic_shapes: bool = False

    # NOTE: Follow torchlib settings in ops_test_data.py
    # Ops compared by output shape only (these ops return uninitialized memory).
    only_shape_check_list = [
        "empty",
        "empty_like",
        "empty_strided",
        "new_empty",
        "new_empty_strided",
    ]

    # op name -> [rtol, atol] used for float32 samples.
    fp32_low_precision_dict = {
        "native_layer_norm": [2e-4, 7e-4],
    }

    # op name -> [rtol, atol] used for float16 samples.
    fp16_low_precision_dict = {
        "addbmm": [2e-1, 2e-2],
        "addcdiv": [3e-2, 1e-3],
        "addcmul": [3e-2, 1e-3],
        "addmv": [5e-2, 3e-2],
        "addr": [3e-3, 4e-3],
        "baddbmm": [3e-2, 1e-3],
        "cumulative_trapezoid": [3e-2, 1e-3],
        "diff": [1e-2, 5e-2],
        "gradient": [3e-3, 4e-3],
        "linalg.multi_dot": [3e-2, 1e-3],
        "linalg.vecdot": [1e-2, 2e-2],
        "linspace": [2e-2, 2e-3],
        "masked.std": [2e-2, 2e-3],
        "masked.var": [2e-2, 2e-2],
        "matmul": [2e-2, 6e-2],
        "nn.functional.batch_norm": [3e-2, 1e-3],
        "nn.functional.binary_cross_entropy": [3e-2, 1e-3],
        "nn.functional.binary_cross_entropy_with_logits": [3e-2, 1e-3],
        "nn.functional.cosine_similarity": [3e-2, 1e-3],
        "nn.functional.cosine_embedding_loss": [1e-2, 1e-3],
        "nn.functional.hardsigmoid": [1e-3, 5e-3],
        "nn.functional.hardswish": [1e-3, 5e-3],
        "nn.functional.hinge_embedding_loss": [4e-1, 3e-3],
        "nn.functional.instance_norm": [1e-2, 1e-3],
        "nn.functional.interpolate": [1e-2, 1e-3],
        "nn.functional.kl_div": [2e-3, 2e-4],
        "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
        "nn.functional.local_response_norm": [1e-2, 5e-3],
        "nn.functional.poisson_nll_loss": [3e-2, 1e-3],
        "native_batch_norm": [3e-2, 1e-3],
        "dot": [3e-2, 1e-3],
        "logit": [3e-2, 1e-3],
        "rsub": [3e-2, 1e-3],
        "sinc": [2e-1, 6e-4],
        "sub": [3e-2, 1e-3],
        "trapezoid": [1e-3, 7e-3],
        "trapz": [1e-3, 7e-3],
    }

    # (op name, variant name) -> [rtol, atol] for float16; checked before
    # fp16_low_precision_dict.
    fp16_low_precision_variant_dict = {
        ("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3],
        ("nn.functional.interpolate", "linear"): [3e-2, 3e-3],
    }

    # NOTE(review): ALL_OPS_IN_DB is built from OPS_DB, so this filter currently
    # keeps every op; it only matters if the two ever diverge.
    @common_device_type.ops(
        [opinfo for opinfo in OPS_DB if opinfo.name in ALL_OPS_IN_DB],
        allowed_dtypes=onnx_test_common.TESTED_DTYPES,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Export each sample of ``op`` to ONNX and check it against PyTorch."""
        _run_test_output_match(self, device, dtype, op)
| |
| |
# Attach the expected skips/xfails to each parameterized test class, then turn
# it into concrete CPU test classes. Iterate over exactly the (opset,
# model_type) pairs given to ``parameterized_class``, so every class looked up
# in globals() below is guaranteed to have been generated. (Iterating every
# TorchModelType member would raise KeyError if the enum gained a member that
# is not parameterized.)
for opset, model_type in _parameterized_class_attrs_and_values()["input_values"]:
    # The name needs to match the parameterized_class name.
    test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}"
    onnx_test_common.add_decorate_info(
        OPS_DB,
        test_class_name,
        "test_output_match",
        opset=opset,
        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )

    common_device_type.instantiate_device_type_tests(
        globals()[test_class_name], globals(), only_for="cpu"
    )

if __name__ == "__main__":
    common_utils.run_tests()