| # Owner(s): ["module: onnx"] | 
 |  | 
 | """Test consistency between the output values of torch.onnx FX exported operators | 
 | and torch operators given the same inputs. | 
 |  | 
 | Usage: | 
 |  | 
 |     1. Test all operators: | 
 |  | 
 |     pytest test/onnx/test_fx_op_consistency.py | 
 |  | 
 |     2. To run tests on a specific operator (e.g. torch.ceil): | 
 |  | 
 |     pytest test/onnx/test_fx_op_consistency.py -k ceil | 
 |     pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention | 
 |  | 
 |     3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. | 
 |  | 
 |     CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int | 
 |  | 
 |     NOTE: Read more on Running and writing tests: | 
 |         https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests | 
 |  | 
 | Note: | 
 |  | 
 |     1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored. | 
 |  | 
    2. Install pytest-xdist to run tests in parallel if running all tests is the goal.
 |  | 
 |     3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and | 
 |     TESTED_OPS lists. See "Modify this section" | 
 |  | 
 | """ | 
 |  | 
 | from __future__ import annotations | 
 |  | 
 | import copy | 
 | import itertools | 
 | import os | 
 | from typing import ( | 
 |     Any, | 
 |     Callable, | 
 |     Collection, | 
 |     List, | 
 |     Mapping, | 
 |     Optional, | 
 |     Tuple, | 
 |     Type, | 
 |     Union, | 
 | ) | 
 |  | 
 | import error_reproduction | 
 |  | 
 | import onnx_test_common | 
 |  | 
 | import parameterized | 
 | import pytest | 
 | import pytorch_test_common | 
 |  | 
 | import torch | 
 | from onnx_test_common import skip, skip_slow, xfail | 
 | from torch.onnx._internal.diagnostics import _rules | 
 | from torch.testing._internal import ( | 
 |     common_device_type, | 
 |     common_methods_invocations, | 
 |     common_utils, | 
 | ) | 
 | from torch.testing._internal.opinfo import core as opinfo_core | 
 |  | 
 |  | 
 | # NOTE: For ATen signature modifications that will break ONNX export, | 
 | # use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip | 
 | # to make the signal apparent for maintainers. | 
 | def xfail_torchlib_forward_compatibility( | 
 |     op_name: str, | 
 |     variant_name: str = "", | 
 |     *, | 
 |     reason: str, | 
 |     github_issue: str, | 
 |     opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, | 
 |     dtypes: Optional[Collection[torch.dtype]] = None, | 
 |     matcher: Optional[Callable[[Any], bool]] = None, | 
 |     enabled_if: bool = True, | 
 | ): | 
 |     """Prefer using this (xfail) over skip when possible. | 
 |  | 
 |     Only skip when the test is not failing consistently. | 
 |     """ | 
 |     return xfail( | 
 |         op_name, | 
 |         variant_name=variant_name, | 
 |         reason=f"{reason}. GitHub Issue: {github_issue}", | 
 |         opsets=opsets, | 
 |         dtypes=dtypes, | 
 |         matcher=matcher, | 
 |         enabled_if=enabled_if, | 
 |     ) | 
 |  | 
 |  | 
 | def skip_torchlib_forward_compatibility( | 
 |     op_name: str, | 
 |     variant_name: str = "", | 
 |     *, | 
 |     reason: str, | 
 |     github_issue: str, | 
 |     opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, | 
 |     dtypes: Optional[Collection[torch.dtype]] = None, | 
 |     matcher: Optional[Callable[[Any], Any]] = None, | 
 |     enabled_if: bool = True, | 
 | ): | 
 |     """Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible. | 
 |  | 
 |     Only skip when the test is not failing consistently. | 
 |     """ | 
 |     return skip( | 
 |         op_name, | 
 |         variant_name=variant_name, | 
 |         reason=f"{reason}. GitHub Issue: {github_issue}", | 
 |         opsets=opsets, | 
 |         dtypes=dtypes, | 
 |         matcher=matcher, | 
 |         enabled_if=enabled_if, | 
 |     ) | 
 |  | 
 |  | 
 | # fmt: off | 
 | # Turn off black formatting to keep the list compact | 
 |  | 
 | # Expected failures for onnx export. | 
 | # The list should be sorted alphabetically by op name. | 
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible.
#     a. If a test is now failing because of xpass (some previous errors
#        are now fixed), remove the corresponding xfail.
#     b. If a test is not failing consistently, use skip.
 | EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = ( | 
 |     xfail( | 
 |         "__getitem__", | 
 |         reason="io_adaper doesn't support __getitem__ input slice(0, 3, None)", | 
 |     ), | 
 |     xfail( | 
 |         "__radd__", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "__rmatmul__", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "__rpow__", | 
 |         dtypes=onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"), | 
 |     ), | 
 |     skip( | 
 |         "_native_batch_norm_legit", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: Assertion error: result mismatch and type error", | 
 |     ), | 
 |     xfail( | 
 |         "_softmax_backward_data", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)") | 
 |     ), | 
 |     xfail( | 
 |         "add", dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Add") | 
 |     ), | 
 |     xfail( | 
 |         "add", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support( | 
 |             "Add", "int8, int16, uint8 have type issue." | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "addbmm", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "addmm", dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Addmm") | 
 |     ), | 
 |     xfail( | 
 |         "addmm", | 
 |         variant_name="decomposed", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Addmm") | 
 |     ), | 
 |     skip( | 
 |         "addmm", dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") | 
 |     ), | 
 |     skip( | 
 |         "addmm", | 
 |         variant_name="decomposed", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") | 
 |     ), | 
 |     xfail( | 
 |         "addr", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support( | 
 |             "Addr", "bool" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "addr", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "all", | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") | 
 |     ), | 
 |     xfail( | 
 |         "allclose", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") | 
 |     ), | 
 |     xfail( | 
 |         "amax", | 
 |         dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), | 
 |     ), | 
 |     xfail( | 
 |         "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16") | 
 |     ), | 
 |     xfail( | 
 |         "aminmax", | 
 |         dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), | 
 |     ), | 
 |     xfail( | 
 |         "any", | 
 |         reason=onnx_test_common.reason_onnx_does_not_support( | 
 |             "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") | 
 |     ), | 
 |     xfail( | 
 |         "arange", | 
 |         dtypes=(torch.uint8,), | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"), | 
 |     ), | 
 |     xfail( | 
 |         "arange", | 
 |         dtypes=(torch.int16, torch.int32), | 
 |         reason="AssertionError: The values for attribute 'shape' do not match", | 
 |     ), | 
 |     xfail( | 
 |         "argmax", | 
 |         dtypes=( | 
 |             torch.int16, | 
 |             torch.int64, | 
 |         ), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "ArgMax", "int16, int64" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "argmin", | 
 |         dtypes=( | 
 |             torch.uint8, | 
 |             torch.int8, | 
 |             torch.int16, | 
 |             torch.int64, | 
 |         ), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "ArgMin", "uint8, int8, int16, int64" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "argwhere", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     skip( | 
 |         "as_strided", | 
 |         variant_name="partial_views", | 
 |         reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults", | 
 |     ), | 
 |     xfail( | 
 |         "atan2", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "baddbmm", | 
 |         dtypes=( | 
 |             torch.uint8, | 
 |             torch.int8, | 
 |             torch.int16, | 
 |         ), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Matmul", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "baddbmm", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "bernoulli", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "bfloat16", | 
 |         reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.", | 
 |     ), | 
 |     xfail( | 
 |         "bincount", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), | 
 |     ), | 
 |     xfail( | 
 |         "bmm", | 
 |         dtypes=( | 
 |             torch.uint8, | 
 |             torch.int8, | 
 |             torch.int16, | 
 |         ), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Matmul", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "broadcast_shapes", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("output is int"), | 
 |     ), | 
 |     xfail( | 
 |         "cauchy", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     skip( | 
 |         "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int") | 
 |     ), | 
 |     xfail( | 
 |         "chalf", | 
 |         reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." | 
 |     ), | 
 |     xfail( | 
 |         "chunk", dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool") | 
 |     ), | 
 |     xfail( | 
 |         "chunk", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Chunk", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "clamp", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Max", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "clamp_max", dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool") | 
 |     ), | 
 |     xfail( | 
 |         "clamp_max", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Max", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "clamp_min", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Max", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "clamp_min", dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool") | 
 |     ), | 
 |     xfail( | 
 |         "constant_pad_nd", | 
 |         dtypes=(torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Constant_pad_nd", "int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "constant_pad_nd", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support( | 
 |             "Constant_pad_nd", "complex64" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "corrcoef", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support( | 
 |             "aten.equal.default" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "cov", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support( | 
 |             "aten.equal.default" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16") | 
 |     ), | 
 |     xfail( | 
 |         "combinations", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), | 
 |     ), | 
 |     xfail( | 
 |         "cross", | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"), | 
 |     ), | 
 |     xfail( | 
 |         "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16") | 
 |     ), | 
 |     skip( | 
 |         "dot", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"), | 
 |     ), | 
 |     xfail( | 
 |         "empty", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." | 
 |     ), | 
 |     xfail( | 
 |         "empty_strided", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "eq", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"), | 
 |     ), | 
 |     xfail( | 
 |         "equal", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") | 
 |     ), | 
 |     xfail( | 
 |         "exponential", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("exponential"), | 
 |     ), | 
 |     xfail( | 
 |         "fft.fft", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.fft2", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.fftn", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.ifft", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.ifft2", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.ifftn", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.irfft", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.irfft2", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "fft.irfftn", | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), | 
 |     ), | 
 |     xfail( | 
 |         "fft.rfft", | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), | 
 |     ), | 
 |     xfail( | 
 |         "fft.rfftn", | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), | 
 |     ), | 
 |     xfail( | 
 |         "fft.rfft2", | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), | 
 |     ), | 
 |     xfail( | 
 |         "floor", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), | 
 |     ), | 
 |     xfail( | 
 |         "floor_divide", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), | 
 |     ), | 
 |     xfail( | 
 |         "full", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "full_like", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "geometric", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "heaviside", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "index_fill", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "index_put", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "index_put", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16,), | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), | 
 |     ), | 
 |     xfail( | 
 |         "isnan", | 
 |         dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"), | 
 |     ), | 
 |     xfail( | 
 |         "istft", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), | 
 |     ), | 
 |     xfail( | 
 |         "item", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "lerp", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "linalg.lstsq", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), | 
 |     ), | 
 |     xfail( | 
 |         "linalg.lstsq", | 
 |         variant_name="grad_oriented", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), | 
 |     ), | 
 |     xfail( | 
 |         "linalg.norm", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "linalg.norm", | 
 |         variant_name="subgradients_at_zero", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "linalg.vecdot", | 
 |         reason="fixme: Assertion error: result shape mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "linspace", | 
 |         dtypes=(torch.int64, torch.int32,), | 
 |         reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", | 
 |     ), | 
 |     xfail( | 
 |         "linspace", | 
 |         variant_name="tensor_overload", | 
 |         dtypes=(torch.int64, torch.int32,), | 
 |         reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", | 
 |     ), | 
 |     xfail( | 
 |         "linspace", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "linspace", | 
 |         variant_name="tensor_overload", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "log_normal", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "log_softmax", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", | 
 |     ), | 
 |     xfail( | 
 |         "log_softmax", | 
 |         variant_name="with_dtype", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", | 
 |     ), | 
 |     xfail( | 
 |         "logcumsumexp", | 
 |         reason=onnx_test_common.reason_onnx_does_not_support( | 
 |             "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") | 
 |     ), | 
 |     xfail( | 
 |         "logical_and", | 
 |         dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"), | 
 |     ), | 
 |     xfail( | 
 |         "logical_not", | 
 |         dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"), | 
 |     ), | 
 |     xfail( | 
 |         "logical_or", | 
 |         dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"), | 
 |     ), | 
 |     xfail( | 
 |         "logical_xor", | 
 |         dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), | 
 |     ), | 
 |     xfail( | 
 |         "logsumexp", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceLogSumExp", "bool, int"), | 
 |     ), | 
 |     xfail( | 
 |         "masked.logsumexp", | 
 |         reason="fixme: https://github.com/onnx/onnx/issues/4986", | 
 |     ), | 
 |     xfail( | 
 |         "masked.amax", | 
 |         reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", | 
 |     ), | 
 |     xfail( | 
 |         "masked.amin", | 
 |         reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", | 
 |     ), | 
 |     xfail( | 
 |         "masked.argmin", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "masked.argmax", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "masked_fill", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "masked.sum", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "masked.log_softmax", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", | 
 |     ), | 
 |     xfail( | 
 |         "masked.mean", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "masked.norm", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "masked.prod", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "masked_select", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"), | 
 |     ), | 
 |     xfail( | 
 |         "max", | 
 |         variant_name="reduction_no_dim", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "max", | 
 |         variant_name="reduction_with_dim", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "max", | 
 |         variant_name="reduction_with_dim", | 
 |         reason="https://github.com/onnx/onnx/issues/4986", | 
 |     ), | 
 |     xfail( | 
 |         "mean", | 
 |         reason="(ReduceMean) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", | 
 |     ), | 
 |     xfail( | 
 |         "min", | 
 |         variant_name="reduction_no_dim", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "min", | 
 |         variant_name="reduction_with_dim", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), | 
 |     ), | 
 |     skip( | 
 |         "mm", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"), | 
 |     ), | 
 |     xfail( | 
 |         "multinomial", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nanquantile", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") | 
 |     ), | 
 |     xfail( | 
 |         "nansum", | 
 |         dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"), | 
 |     ), | 
 |     xfail( | 
 |         "narrow", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), | 
 |     ), | 
 |     xfail( | 
 |         "native_batch_norm", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: https://github.com/microsoft/onnxscript/issues/1269", | 
 |     ), | 
 |     xfail( | 
 |         "native_layer_norm", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", | 
 |     ), | 
 |     xfail( | 
 |         "new_full", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64") | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.adaptive_avg_pool2d", | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \ | 
 |             maximum recursion depth exceeded while calling a Python object"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.adaptive_avg_pool3d", | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.alpha_dropout", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.avg_pool1d", | 
 |         dtypes=onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.avg_pool2d", | 
 |         dtypes=onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.avg_pool3d", | 
 |         dtypes=onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.batch_norm", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: https://github.com/microsoft/onnxscript/issues/1270", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv_transpose1d", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv_transpose2d", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv_transpose3d", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), | 
 |     ), | 
 |     skip( | 
 |         "nn.functional.conv_transpose1d", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     skip( | 
 |         "nn.functional.conv_transpose2d", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     skip( | 
 |         "nn.functional.conv_transpose3d", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv1d", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv2d", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv2d", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv3d", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv3d", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.ctc_loss", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.dropout", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.dropout2d", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.dropout3d", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.feature_alpha_dropout", | 
 |         variant_name="with_train", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.feature_alpha_dropout", | 
 |         variant_name="without_train", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.fractional_max_pool2d", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.fractional_max_pool3d", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.gaussian_nll_loss", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.grid_sample", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.group_norm", | 
 |         dtypes=(torch.float16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.local_response_norm", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.linear", | 
 |         dtypes=onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.max_pool2d", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.max_pool3d", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.multi_head_attention_forward", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.one_hot", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.pad", | 
 |         variant_name="replicate", | 
 |         reason="fixme: ORT error: padding size", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.pad", | 
 |         variant_name="replicate_negative", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.pad", | 
 |         variant_name="reflect", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.rrelu", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.rrelu", | 
 |         dtypes=(torch.int64,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"), | 
 |     ), | 
 |     skip( | 
 |         "nn.functional.scaled_dot_product_attention", | 
 |         matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, | 
 |         reason="dropout is random so the results do not match", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.scaled_dot_product_attention", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.selu", | 
 |         reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.soft_margin_loss", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.tanhshrink", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "nonzero", | 
 |         dtypes=(torch.int8, torch.int16), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"), | 
 |     ), | 
 |     xfail( | 
 |         "normal", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "normal", | 
 |         variant_name="in_place", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "normal", | 
 |         variant_name="number_mean", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "ones", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." | 
 |     ), | 
 |     xfail( | 
 |         "pca_lowrank", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "quantile", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") | 
 |     ), | 
 |     xfail( | 
 |         "rand_like", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "randint", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "randint_like", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "randn", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "randn_like", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "resize_", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") | 
 |     ), | 
 |     xfail( | 
 |         "resize_as_", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") | 
 |     ), | 
 |     xfail( | 
 |         "round", | 
 |         dtypes=onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"), | 
 |     ), | 
 |     xfail( | 
 |         "rsub", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Mul", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "scatter_add", | 
 |         dtypes=(torch.float16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), | 
 |     ), | 
 |     xfail( | 
 |         "scatter_reduce", | 
 |         variant_name="sum", | 
 |         dtypes=(torch.float16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), | 
 |     ), | 
 |     xfail( | 
 |         "scatter_reduce", | 
 |         variant_name="prod", | 
 |         dtypes=(torch.float16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"), | 
 |     ), | 
 |     xfail( | 
 |         "scatter_reduce", | 
 |         variant_name="amin", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"), | 
 |     ), | 
 |     xfail( | 
 |         "scatter_reduce", | 
 |         variant_name="amax", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"), | 
 |     ), | 
 |     xfail( | 
 |         "scatter_reduce", | 
 |         variant_name="mean", | 
 |         reason="ONNX doesn't support reduce='mean' option", | 
 |     ), | 
 |     xfail( | 
 |         "sign", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "signal.windows.kaiser", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"), | 
 |     ), | 
 |     xfail( | 
 |         "softmax", | 
 |         dtypes=(torch.float16,), | 
 |         reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438" | 
 |     ), | 
 |     xfail( | 
 |         "sparse.mm", | 
 |         variant_name="reduce", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), | 
 |     ), | 
 |     xfail( | 
 |         "sparse.sampled_addmm", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), | 
 |     ), | 
 |     xfail( | 
 |         "special.erfcx", | 
 |         dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"), | 
 |     ), | 
 |     xfail( | 
 |         "special.erfcx", | 
 |         dtypes=onnx_test_common.FLOAT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), | 
 |     ), | 
 |     xfail( | 
 |         "special.ndtr", | 
 |         dtypes=(torch.float16,), | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "split", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "split", | 
 |         variant_name="list_args", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "split_with_sizes", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "square", | 
 |         dtypes=(torch.int8, torch.uint8, torch.int16), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), | 
 |     ), | 
 |     xfail( | 
 |         "squeeze", | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "squeeze", | 
 |         variant_name="multiple", | 
 |         reason="fixme: https://github.com/microsoft/onnxscript/issues/1264", | 
 |     ), | 
 |     xfail( | 
 |         "svd_lowrank", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "std_mean", | 
 |         reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." | 
 |     ), | 
 |     xfail( | 
 |         "std_mean", | 
 |         variant_name="unbiased", | 
 |         reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." | 
 |     ), | 
 |     xfail( | 
 |         "stft", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), | 
 |     ), | 
 |     xfail( | 
 |         "sub", | 
 |         dtypes=(torch.uint8, torch.int8, torch.int16), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Mul", "uint8, int8, int16" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "take", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), | 
 |     ), | 
 |     xfail( | 
 |         "tensor_split", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), | 
 |     ), | 
 |     xfail( | 
 |         "to", | 
 |         dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.bool, torch.complex64), | 
 |         # model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, | 
 |         reason="This op requires torch.dtype as input, which is not supported currently.", | 
 |     ), | 
 |     xfail( | 
 |         "topk", | 
 |         dtypes=(torch.int64, torch.int32), | 
 |         reason="fixme: Assertion error: result mismatch", | 
 |     ), | 
 |     xfail( | 
 |         "tril", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), | 
 |     ), | 
 |     xfail( | 
 |         "triu", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), | 
 |     ), | 
 |     xfail( | 
 |         "trunc", | 
 |         dtypes=onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), | 
 |     ), | 
 |     xfail( | 
 |         "unbind", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "unflatten", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_does_not_support("Unflatten") | 
 |     ), | 
 |     xfail( | 
 |         "uniform", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), | 
 |     ), | 
 |     xfail( | 
 |         "unique", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), | 
 |     ), | 
 |     xfail( | 
 |         "unique_consecutive", | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), | 
 |     ), | 
 |     xfail( | 
 |         "unravel_index", | 
 |         dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), | 
 |     ), | 
 |     xfail( | 
 |         "unsafe_split", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "unsafe_chunk", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "where", | 
 |         dtypes=onnx_test_common.BOOL_TYPES, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), | 
 |     ), | 
 |     xfail( | 
 |         "zeros", | 
 |         dtypes=onnx_test_common.COMPLEX_TYPES, | 
 |         reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." | 
 |     ), | 
 |     # SLOW TESTS (All are xfails if we run them) | 
 |     # TODO: https://github.com/pytorch/pytorch/issues/117118 | 
 |     skip_slow( | 
 |         "cdist", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "histogram", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "histogramdd", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "linalg.lu_solve", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "linalg.solve_triangular", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "linalg.svd", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "logspace", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "logspace", | 
 |         variant_name="tensor_overload", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "max_pool2d_with_indices_backward", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.interpolate", | 
 |         variant_name="bicubic", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.max_unpool1d", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.max_unpool2d", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.max_unpool3d", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.max_pool1d", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.max_pool2d", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.max_pool3d", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "nn.functional.unfold", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "ormqr", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "searchsorted", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 |     skip_slow( | 
 |         "svd", | 
 |         reason="fixme: Test sets are too many.", | 
 |     ), | 
 | ) | 
 | # fmt: on | 
 |  | 
 | SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = ( | 
 |     skip( | 
 |         "_native_batch_norm_legit", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="https://github.com/pytorch/pytorch/issues/115106", | 
 |     ), | 
 |     xfail( | 
 |         "addmm",  # xfail can't only use dtypes to catch all cases | 
 |         matcher=lambda sample: sample.input.dtype | 
 |         in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Gemm", "uint8, int8, int16, int32, int64" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "addmm", | 
 |         matcher=lambda sample: sample.args[0].numel() == 0, | 
 |         reason="ONNX Runtime does not support empty tensors multiplication", | 
 |     ), | 
 |     xfail( | 
 |         "addmm", | 
 |         variant_name="decomposed", | 
 |         matcher=lambda sample: sample.args[0].numel() == 0, | 
 |         reason="ONNX Runtime does not support empty tensors multiplication", | 
 |     ), | 
 |     xfail( | 
 |         "amax", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0 | 
 |         and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), | 
 |         reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", | 
 |     ), | 
 |     xfail( | 
 |         "amin", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0 | 
 |         and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), | 
 |         reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", | 
 |     ), | 
 |     xfail( | 
 |         "aminmax", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0 | 
 |         and sample.kwargs.get("dim") is not None, | 
 |         reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", | 
 |     ), | 
 |     skip( | 
 |         "cat", | 
 |         matcher=lambda sample: sample.input[0].equal(torch.tensor([])), | 
 |         reason="core dump - cat does not support zero-dim tensors yet", | 
 |     ), | 
 |     xfail( | 
 |         "index_add", | 
 |         matcher=lambda sample: len(sample.input.shape) < 2, | 
 |         reason="fixme: https://github.com/microsoft/onnxscript/issues/1212", | 
 |     ), | 
 |     xfail( | 
 |         "index_add", | 
 |         matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, | 
 |         reason="fixme: aten::index_put indices contains None when dim is -1", | 
 |     ), | 
 |     xfail( | 
 |         "index_copy", | 
 |         matcher=lambda sample: len(sample.input.shape) < 2, | 
 |         reason="fixme: https://github.com/microsoft/onnxscript/issues/1212", | 
 |     ), | 
 |     xfail( | 
 |         "index_copy", | 
 |         matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, | 
 |         reason="fixme: aten::index_put indices contains None when dim is -1", | 
 |     ), | 
 |     xfail( | 
 |         "index_put", | 
 |         matcher=lambda sample: (sample.args[0][0].dtype == torch.bool) | 
 |         and (sample.kwargs.get("accumulate") is False), | 
 |         reason=onnx_test_common.reason_dynamo_does_not_support( | 
 |             "https://github.com/pytorch/pytorch/issues/101150" | 
 |         ), | 
 |     ), | 
 |     skip( | 
 |         "linalg.multi_dot", | 
 |         matcher=lambda sample: sum([torch.numel(input) for input in sample.input]) == 0, | 
 |         reason="fixme: Undefined", | 
 |     ), | 
 |     skip( | 
 |         "log_softmax", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0, | 
 |         reason="fixme: LogSoftMax does not support empty tensor as input", | 
 |     ), | 
 |     skip( | 
 |         "log_softmax", | 
 |         variant_name="with_dtype", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0, | 
 |         reason="fixme: LogSoftMax does not support empty tensor as input", | 
 |     ), | 
 |     xfail( | 
 |         "logsumexp", | 
 |         matcher=lambda sample: isinstance(sample.input, torch.Tensor) | 
 |         and len(sample.input.shape) == 0, | 
 |         reason="fixme: IsScalar", | 
 |     ), | 
 |     skip( | 
 |         "masked.log_softmax", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0, | 
 |         reason="fixme: LogSoftMax does not support empty tensor as input", | 
 |     ), | 
 |     skip( | 
 |         "matmul", | 
 |         matcher=lambda sample: torch.numel(sample.input) == 0, | 
 |         reason="values of matmul of [m, 0] and [0, n] matrices are undefined", | 
 |     ), | 
 |     xfail( | 
 |         "min", | 
 |         variant_name="reduction_with_dim", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0, | 
 |         reason="fixme: https://github.com/onnx/onnx/issues/4986", | 
 |     ), | 
 |     skip( | 
 |         "mm", | 
 |         matcher=lambda sample: torch.numel(sample.input) == 0, | 
 |         reason="values of matmul of [m, 0] and [0, n] matrices are undefined", | 
 |     ), | 
 |     xfail( | 
 |         "native_batch_norm", | 
 |         matcher=lambda sample: sample.args[-3] is True | 
 |         and any(arg is not None for arg in sample.args[2:4]), | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="https://github.com/pytorch/pytorch/issues/115106", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.avg_pool1d", | 
 |         matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) | 
 |         and ( | 
 |             sample.kwargs.get("count_include_pad") is True | 
 |             or sample.input.shape[2] | 
 |             % ( | 
 |                 sample.args[0][0] | 
 |                 if isinstance(sample.args[0], tuple) | 
 |                 else sample.args[0] | 
 |             ) | 
 |             != 0 | 
 |         ), | 
 |         reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.avg_pool2d", | 
 |         matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) | 
 |         or (sample.kwargs.get("divisor_override") is not None), | 
 |         reason="ONNX doesn't support divisor_override argument", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.avg_pool3d", | 
 |         matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, | 
 |         reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.avg_pool3d", | 
 |         matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) | 
 |         or (sample.kwargs.get("divisor_override") is not None), | 
 |         reason="ONNX doesn't support divisor_override argument", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.batch_norm", | 
 |         matcher=lambda sample: sample.kwargs.get("training") is True | 
 |         and any(arg is not None for arg in sample.args[2:4]), | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv2d", | 
 |         matcher=lambda sample: sample.kwargs.get("padding") == "valid", | 
 |         reason="fixme: https://github.com/pytorch/pytorch/issues/117054", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.conv3d", | 
 |         matcher=lambda sample: sample.kwargs.get("padding") == "valid", | 
 |         reason="fixme: https://github.com/pytorch/pytorch/issues/117054", | 
 |     ), | 
 |     skip( | 
 |         "nn.functional.cross_entropy", | 
 |         matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int), | 
 |         reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.embedding", | 
 |         matcher=lambda sample: sample.kwargs.get("max_norm") is not None, | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="https://github.com/pytorch/pytorch/issues/115106", | 
 |     ), | 
 |     skip_torchlib_forward_compatibility( | 
 |         "nn.functional.embedding_bag", | 
 |         matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support( | 
 |             "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. " | 
 |             "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided" | 
 |         ), | 
 |         github_issue="https://github.com/microsoft/onnxscript/issues/1056", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.group_norm", | 
 |         matcher=lambda sample: torch.numel(sample.input) == 0, | 
 |         reason=onnx_test_common.reason_onnx_runtime_does_not_support( | 
 |             "Reshape", "empty tensor" | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.instance_norm", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         matcher=lambda sample: sample.kwargs.get("running_mean") is not None | 
 |         or sample.input.dtype in (torch.float16,), | 
 |         reason="fixme: KeyError: 'self___kwargs__running_mean'", | 
 |     ), | 
 |     xfail( | 
 |         "nn.functional.max_pool3d", | 
 |         matcher=lambda sample: sample.kwargs.get("ceil_mode") is True | 
 |         and sample.kwargs.get("padding") == 1, | 
 |         reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", | 
 |     ), | 
 |     xfail( | 
 |         "nonzero", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0 | 
 |         and sample.kwargs.get("as_tuple", False) is False, | 
 |         reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, | 
 |     ), | 
 |     xfail( | 
 |         "nonzero", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason=onnx_test_common.reason_onnx_script_does_not_support( | 
 |             "aten::_assert_async.msg", | 
 |             "https://github.com/pytorch/pytorch/issues/112443", | 
 |         ), | 
 |     ), | 
 |     xfail( | 
 |         "scatter_add", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0, | 
 |         reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch", | 
 |     ), | 
 |     skip( | 
 |         "scatter_reduce", | 
 |         variant_name="amax", | 
        # ONNX has no include_self parameter; it always behaves like include_self=True
 |         matcher=lambda sample: sample.kwargs.get("include_self") is False, | 
 |         reason="ONNX does't support include_self=False option", | 
 |     ), | 
 |     skip( | 
 |         "scatter_reduce", | 
 |         variant_name="amin", | 
        # ONNX has no include_self parameter; it always behaves like include_self=True
 |         matcher=lambda sample: sample.kwargs.get("include_self") is False, | 
 |         reason="ONNX does't support include_self=False option", | 
 |     ), | 
 |     skip( | 
 |         "scatter_reduce", | 
 |         variant_name="prod", | 
        # ONNX has no include_self parameter; it always behaves like include_self=True
 |         matcher=lambda sample: sample.kwargs.get("include_self") is False, | 
 |         reason="ONNX does't support include_self=False option", | 
 |     ), | 
 |     skip( | 
 |         "scatter_reduce", | 
 |         variant_name="sum", | 
        # ONNX has no include_self parameter; it always behaves like include_self=True
 |         matcher=lambda sample: sample.kwargs.get("include_self") is False, | 
 |         reason="ONNX does't support include_self=False option", | 
 |     ), | 
 |     skip( | 
 |         "softmax", | 
 |         matcher=lambda sample: len(sample.input.shape) == 0, | 
 |         reason="fixme: LogSoftMax does not support empty tensor as input", | 
 |     ), | 
 |     xfail( | 
 |         "t", | 
 |         matcher=lambda sample: isinstance(sample.input, torch.Tensor) | 
 |         and len(sample.input.shape) < 2, | 
 |         reason="fixme: IsScalar", | 
 |     ), | 
 |     xfail( | 
 |         "unflatten", | 
 |         reason="Logic not implemented for size 0 inputs in op.Reshape", | 
 |         matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), | 
 |     ), | 
 |     skip( | 
 |         "signal.windows.hamming", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="does not match node name", | 
 |     ), | 
 |     skip( | 
 |         "signal.windows.general_hamming", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="does not match node name", | 
 |     ), | 
 |     skip( | 
 |         "signal.windows.blackman", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="does not match node name", | 
 |     ), | 
 |     skip( | 
 |         "signal.windows.general_cosine", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="does not match node name", | 
 |     ), | 
 |     skip( | 
 |         "signal.windows.hann", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="does not match node name", | 
 |     ), | 
 |     skip( | 
 |         "signal.windows.nuttall", | 
 |         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |         reason="does not match node name", | 
 |     ), | 
 | ) | 
 |  | 
 | OPS_DB = copy.deepcopy(common_methods_invocations.op_db) | 
 | OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS) | 
 | ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) | 
 |  | 
 |  | 
 | def _torch_size_flatten_spec(d: List[Any], spec: Any) -> List[Any]: | 
 |     return [d[i] for i in range(spec.num_children)] | 
 |  | 
 | torch.fx._pytree.register_pytree_flatten_spec( | 
 |     torch.Size, | 
 |     _torch_size_flatten_spec, | 
 | ) | 
 |  | 
 |  | 
 | class SingleOpModel(torch.nn.Module): | 
 |     """Test model to wrap around a single op for export.""" | 
 |  | 
 |     def __init__(self, op, kwargs): | 
 |         super().__init__() | 
 |         self.operator = op | 
 |         self.kwargs = kwargs | 
 |  | 
 |     def forward(self, *args): | 
 |         return self.operator(*args, **self.kwargs) | 
 |  | 
 |  | 
 | def _should_skip_xfail_test_sample( | 
 |     op_name: str, | 
 |     variant_test_name: str, | 
 |     sample, | 
 |     model_type: pytorch_test_common.TorchModelType, | 
 | ) -> Tuple[Optional[str], Optional[str]]: | 
 |     """Check if the test sample should be skipped or xfailed. | 
 |  | 
 |     If the xfail/skip decorator meta is matched with its op_name and model_type, | 
 |     return the test_behavior and reason. Otherwise, return None, None. Note that | 
 |     if the matcher is None, the test is decorator_meta is meant to skip/xfail all model types. | 
 |  | 
 |     Args: | 
 |         op_name: The name of the op. | 
 |         sample: The test sample. | 
 |         model_type: The model type of the test. | 
 |  | 
 |     Returns: | 
 |         A tuple of (test_behavior, reason). test_behavior is either "skip" or "xfail". | 
 |         reason is the reason for the test_behavior. | 
 |     """ | 
 |  | 
 |     if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS: | 
 |         return None, None | 
 |     for decorator_meta in SKIP_XFAIL_SUBTESTS: | 
 |         # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small. | 
 |         # NOTE: If model_type is None, the test is decorator_meta is meant to skip/xfail all model types. | 
 |         if ( | 
 |             decorator_meta.op_name == op_name | 
 |             and decorator_meta.variant_name == variant_test_name | 
 |         ) and ( | 
 |             model_type == decorator_meta.model_type or decorator_meta.model_type is None | 
 |         ): | 
 |             if decorator_meta.matcher is None and decorator_meta.model_type is None: | 
 |                 raise TypeError( | 
 |                     "Either Matcher or model_type must be defined in sub xfail and skip." | 
 |                 ) | 
 |             if decorator_meta.matcher is not None and decorator_meta.matcher(sample): | 
 |                 return decorator_meta.test_behavior, decorator_meta.reason | 
 |             elif decorator_meta.matcher is None: | 
 |                 # xfail/skip the whole test of the model type without matcher | 
 |                 return decorator_meta.test_behavior, decorator_meta.reason | 
 |     return None, None | 
 |  | 
 |  | 
 | def _compare_onnx_and_torch_exported_program( | 
 |     torch_exported_program, | 
 |     onnx_exported_program, | 
 |     input_args, | 
 |     input_kwargs=None, | 
 |     test_name=None, | 
 |     sample_num=None, | 
 |     sample_kwargs=None, | 
 |     rtol=1e-03, | 
 |     atol=1e-07, | 
 |     only_check_shape=False, | 
 | ): | 
 |     # avoid mutable default argument | 
 |     if input_kwargs is None: | 
 |         input_kwargs = {} | 
 |  | 
 |     # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. | 
 |     # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. | 
 |     # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() | 
 |     onnx_outputs = onnx_exported_program(*input_args, **input_kwargs) | 
 |     if isinstance(torch_exported_program, torch.export.ExportedProgram): | 
 |         torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) | 
 |     else: | 
 |         torch_outputs = torch_exported_program(*input_args, **input_kwargs) | 
 |     torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( | 
 |         torch_outputs | 
 |     ) | 
 |     if len(torch_outputs_onnx_format) != len(onnx_outputs): | 
 |         raise AssertionError( | 
 |             f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" | 
 |         ) | 
 |  | 
 |     for j, (torch_output, onnx_output) in enumerate( | 
 |         zip(torch_outputs_onnx_format, onnx_outputs) | 
 |     ): | 
 |         if only_check_shape: | 
 |             assert torch_output.shape == onnx_output.shape | 
 |         else: | 
 |             try: | 
 |                 torch.testing.assert_close( | 
 |                     torch.tensor(onnx_output), | 
 |                     torch_output, | 
 |                     rtol=rtol, | 
 |                     atol=atol, | 
 |                     equal_nan=True, | 
 |                 ) | 
 |             except AssertionError as e: | 
 |                 if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": | 
 |                     error_reproduction.create_mismatch_report( | 
 |                         test_name, | 
 |                         sample_num, | 
 |                         onnx_exported_program.model_proto, | 
 |                         input_args, | 
 |                         sample_kwargs, | 
 |                         torch.tensor(onnx_output), | 
 |                         torch_output, | 
 |                         e, | 
 |                     ) | 
 |                 if len(torch_outputs_onnx_format) > 1: | 
 |                     raise AssertionError(f"Output {j} mismatch") from e | 
 |                 raise | 
 |  | 
 |  | 
 | def _run_test_output_match( | 
 |     test_suite: onnx_test_common._TestONNXRuntime, | 
 |     device: str, | 
 |     dtype: torch.dtype, | 
 |     op: opinfo_core.OpInfo, | 
 | ): | 
 |     # device is provided by instantiate_device_type_tests, but we only want to run in cpu. | 
 |     assert device == "cpu" | 
 |     samples = op.sample_inputs( | 
 |         device, | 
 |         dtype, | 
 |         requires_grad=False, | 
 |     ) | 
 |     for i, cpu_sample in enumerate(samples): | 
 |         inputs = (cpu_sample.input, *cpu_sample.args) | 
 |         # Provide the repr to subtest because tensors are not serializable in parallel test runs | 
 |  | 
 |         with test_suite.subTest( | 
 |             opset=test_suite.opset_version, | 
 |             sample_num=i, | 
 |             inputs=repr(inputs), | 
 |             kwargs=repr(cpu_sample.kwargs), | 
 |         ): | 
 |             test_behavior, reason = _should_skip_xfail_test_sample( | 
 |                 op.name, op.variant_test_name, cpu_sample, test_suite.model_type | 
 |             ) | 
 |             with onnx_test_common.normal_xfail_skip_test_behaviors( | 
 |                 test_behavior, reason | 
 |             ): | 
 |                 model = SingleOpModel(op.op, cpu_sample.kwargs) | 
 |                 model.eval() | 
 |  | 
 |                 if ( | 
 |                     dtype == torch.float32 | 
 |                     and op.name in test_suite.fp32_low_precision_dict | 
 |                 ): | 
 |                     rtol = test_suite.fp32_low_precision_dict[op.name][0] | 
 |                     atol = test_suite.fp32_low_precision_dict[op.name][1] | 
 |                 elif dtype == torch.float32: | 
 |                     # Relax atol and rtol for float32 based on empirical results | 
 |                     rtol = 1e-5 | 
 |                     atol = 2e-5 | 
 |                 elif ( | 
 |                     dtype == torch.float16 | 
 |                     and (op.name, op.variant_test_name) | 
 |                     in test_suite.fp16_low_precision_variant_dict | 
 |                 ): | 
 |                     rtol = test_suite.fp16_low_precision_variant_dict[ | 
 |                         (op.name, op.variant_test_name) | 
 |                     ][0] | 
 |                     atol = test_suite.fp16_low_precision_variant_dict[ | 
 |                         (op.name, op.variant_test_name) | 
 |                     ][1] | 
 |                 elif ( | 
 |                     dtype == torch.float16 | 
 |                     and op.name in test_suite.fp16_low_precision_dict | 
 |                 ): | 
 |                     rtol = test_suite.fp16_low_precision_dict[op.name][0] | 
 |                     atol = test_suite.fp16_low_precision_dict[op.name][1] | 
 |                 else: | 
 |                     rtol = None | 
 |                     atol = None | 
 |  | 
 |                 if ( | 
 |                     test_suite.model_type | 
 |                     == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM | 
 |                 ): | 
 |                     try: | 
 |                         model = torch.export.export(model, inputs) | 
 |                     except AssertionError as e: | 
 |                         # NOTE: avoid fake_mode detection bug in torch.export.export | 
 |                         pytest.xfail( | 
 |                             onnx_test_common.reason_dynamo_does_not_support(str(e)) | 
 |                         ) | 
 |  | 
 |                 try: | 
 |                     onnx_program = torch.onnx.dynamo_export( | 
 |                         model, | 
 |                         *inputs, | 
 |                     ) | 
 |                 except torch.onnx.OnnxExporterError as e: | 
 |                     # NOTE: If the model has unsupported nodes, we will skip the test | 
 |                     # with non-strict xfail. Otherwise, we will raise the error. | 
 |                     if hasattr( | 
 |                         e.__cause__, "diagnostic" | 
 |                     ) and e.__cause__.diagnostic.rule in ( | 
 |                         _rules._POERules.no_symbolic_function_for_call_function, | 
 |                         _rules._POERules.unsupported_fx_node_analysis, | 
 |                     ): | 
 |                         pytest.xfail( | 
 |                             onnx_test_common.reason_onnx_script_does_not_support(str(e)) | 
 |                         ) | 
 |                     else: | 
 |                         raise e | 
 |                 _compare_onnx_and_torch_exported_program( | 
 |                     model, | 
 |                     onnx_program, | 
 |                     inputs, | 
 |                     test_name=test_suite.id(), | 
 |                     sample_num=i, | 
 |                     sample_kwargs=cpu_sample.kwargs, | 
 |                     rtol=rtol, | 
 |                     atol=atol, | 
 |                     only_check_shape=(op.name in test_suite.only_shape_check_list), | 
 |                 ) | 
 |  | 
 |  | 
 | def _parameterized_class_attrs_and_values(): | 
 |     input_values = [] | 
 |     input_values.extend( | 
 |         itertools.product( | 
 |             (opset for opset in onnx_test_common.FX_TESTED_OPSETS), | 
 |             ( | 
 |                 pytorch_test_common.TorchModelType.TORCH_NN_MODULE, | 
 |                 pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, | 
 |             ), | 
 |         ) | 
 |     ) | 
 |     return { | 
 |         "attrs": ["opset_version", "model_type"], | 
 |         "input_values": input_values, | 
 |     } | 
 |  | 
 |  | 
 | def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): | 
 |     """Combine class name with the parameterized arguments. | 
 |  | 
 |     This function is passed to `parameterized.parameterized_class` as the | 
 |     `class_name_func` argument. | 
 |     """ | 
 |     suffixes = [] | 
 |     for k, v in input_dicts.items(): | 
 |         suffixes.append(f"{k}_{v}") | 
 |     return f"{cls.__name__}_{'_'.join(suffixes)}" | 
 |  | 
 |  | 
 | @parameterized.parameterized_class( | 
 |     **_parameterized_class_attrs_and_values(), | 
 |     class_name_func=_parameterize_class_name, | 
 | ) | 
 | class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): | 
 |     """Test output consistency between exported ONNX models and PyTorch eager mode. | 
 |  | 
 |     This is a parameterized test suite. | 
 |     """ | 
 |  | 
 |     opset_version = -1 | 
 |     op_level_debug: bool = False | 
 |     dynamic_shapes: bool = False | 
 |     model_type: pytorch_test_common.TorchModelType = ( | 
 |         pytorch_test_common.TorchModelType.TORCH_NN_MODULE | 
 |     ) | 
 |  | 
 |     # NOTE: Follow torchlib settings in ops_test_data.py | 
 |     only_shape_check_list = [ | 
 |         "empty", | 
 |         "empty_like", | 
 |         "empty_strided", | 
 |         "new_empty", | 
 |         "new_empty_strided", | 
 |     ] | 
 |  | 
 |     fp32_low_precision_dict = { | 
 |         "native_layer_norm": [2e-4, 7e-4], | 
 |     } | 
 |  | 
 |     fp16_low_precision_dict = { | 
 |         "addbmm": [2e-1, 2e-2], | 
 |         "addcdiv": [3e-2, 1e-3], | 
 |         "addcmul": [3e-2, 1e-3], | 
 |         "addmv": [5e-2, 3e-2], | 
 |         "addr": [3e-3, 4e-3], | 
 |         "baddbmm": [3e-2, 1e-3], | 
 |         "cumulative_trapezoid": [3e-2, 1e-3], | 
 |         "diff": [1e-2, 5e-2], | 
 |         "gradient": [3e-3, 4e-3], | 
 |         "linalg.multi_dot": [3e-2, 1e-3], | 
 |         "linalg.vecdot": [1e-2, 2e-2], | 
 |         "linspace": [2e-2, 2e-3], | 
 |         "masked.std": [2e-2, 2e-3], | 
 |         "masked.var": [2e-2, 2e-2], | 
 |         "matmul": [2e-2, 6e-2], | 
 |         "nn.functional.batch_norm": [3e-2, 1e-3], | 
 |         "nn.functional.binary_cross_entropy": [3e-2, 1e-3], | 
 |         "nn.functional.binary_cross_entropy_with_logits": [3e-2, 1e-3], | 
 |         "nn.functional.cosine_similarity": [3e-2, 1e-3], | 
 |         "nn.functional.cosine_embedding_loss": [1e-2, 1e-3], | 
 |         "nn.functional.hardsigmoid": [1e-3, 5e-3], | 
 |         "nn.functional.hardswish": [1e-3, 5e-3], | 
 |         "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], | 
 |         "nn.functional.instance_norm": [1e-2, 1e-3], | 
 |         "nn.functional.interpolate": [1e-2, 1e-3], | 
 |         "nn.functional.kl_div": [2e-3, 2e-4], | 
 |         "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3], | 
 |         "nn.functional.local_response_norm": [1e-2, 5e-3], | 
 |         "nn.functional.poisson_nll_loss": [3e-2, 1e-3], | 
 |         "native_batch_norm": [3e-2, 1e-3], | 
 |         "dot": [3e-2, 1e-3], | 
 |         "logit": [3e-2, 1e-3], | 
 |         "rsub": [3e-2, 1e-3], | 
 |         "sinc": [2e-1, 6e-4], | 
 |         "sub": [3e-2, 1e-3], | 
 |         "trapezoid": [1e-3, 7e-3], | 
 |         "trapz": [1e-3, 7e-3], | 
 |     } | 
 |  | 
 |     fp16_low_precision_variant_dict = { | 
 |         ("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3], | 
 |         ("nn.functional.interpolate", "linear"): [3e-2, 3e-3], | 
 |     } | 
 |  | 
 |     @common_device_type.ops( | 
 |         [op for op in OPS_DB if op.name in ALL_OPS_IN_DB], | 
 |         allowed_dtypes=onnx_test_common.TESTED_DTYPES, | 
 |     ) | 
 |     def test_output_match(self, device: str, dtype: torch.dtype, op): | 
 |         """Test the ONNX exporter.""" | 
 |         _run_test_output_match(self, device, dtype, op) | 
 |  | 
 |  | 
 | for opset in onnx_test_common.FX_TESTED_OPSETS: | 
 |     for model_type in pytorch_test_common.TorchModelType: | 
 |         # The name needs to match the parameterized_class name. | 
 |         test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}" | 
 |         onnx_test_common.add_decorate_info( | 
 |             OPS_DB, | 
 |             test_class_name, | 
 |             "test_output_match", | 
 |             opset=opset, | 
 |             skip_or_xfails=EXPECTED_SKIPS_OR_FAILS, | 
 |         ) | 
 |  | 
 |         common_device_type.instantiate_device_type_tests( | 
 |             globals()[test_class_name], globals(), only_for="cpu" | 
 |         ) | 
 |  | 
 | if __name__ == "__main__": | 
 |     common_utils.run_tests() |