[ONNX] Update decomposition table to core ATen ops (#127353)
Fixes #125894
Previous to this PR, there are ATen core ops missing in the decomposition table because we thought they might be decomposed into prim ops, as they are under _refs. The PR picks them back according to https://github.com/pytorch/pytorch/blob/f6ef832e87a8ea01e6df93b27a2367cccb6b6171/torch/_decomp/__init__.py#L253
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127353
Approved by: https://github.com/justinchuby
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index 760ede5..4a41716 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -311,6 +311,11 @@
reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"),
),
xfail(
+ "block_diag",
+ dtypes=onnx_test_common.COMPLEX_TYPES,
+ reason=onnx_test_common.reason_onnx_runtime_does_not_support("Block_diag", "complex"),
+ ),
+ xfail(
"bmm",
dtypes=(
torch.uint8,
@@ -408,10 +413,6 @@
reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"),
),
xfail(
- "cross",
- reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"),
- ),
- xfail(
"diag",
dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"),
@@ -545,6 +546,11 @@
reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64")
),
xfail(
+ "index_fill",
+ dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES,
+ reason="fixme: Constant input list has None. ONNXScript does not support None in constant list."
+ ),
+ xfail(
"index_put",
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"),
@@ -587,6 +593,10 @@
reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"),
),
xfail(
+ "linalg.matrix_power",
+ reason="fixme: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 2, 2])."
+ ),
+ xfail(
"linalg.norm",
reason="fixme: Assertion error: result mismatch",
),
@@ -964,6 +974,10 @@
reason="fixme: ONNX Runtime does not support int32/64 inputs",
),
xfail(
+ "nn.functional.pixel_unshuffle",
+ reason=onnx_test_common.reason_onnx_script_does_not_support("aten.pixel_unshuffle.default"),
+ ),
+ xfail(
"nn.functional.poisson_nll_loss",
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason="fixme: result mismatch with NaN.",
@@ -1108,6 +1122,11 @@
reason="ONNX doesn't support reduce='mean' option",
),
xfail(
+ "sgn",
+ dtypes=onnx_test_common.BOOL_TYPES,
+ reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"),
+ ),
+ xfail(
"sign",
dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"),
@@ -1141,6 +1160,11 @@
reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"),
),
xfail(
+ "special.log_ndtr",
+ dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES,
+ reason="fixme: Assertion error: result mismatch",
+ ),
+ xfail(
"special.ndtr",
dtypes=(torch.float16,),
reason="fixme: Assertion error: result mismatch",
@@ -1160,15 +1184,6 @@
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
),
xfail(
- "std_mean",
- reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple."
- ),
- xfail(
- "std_mean",
- variant_name="unbiased",
- reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple."
- ),
- xfail(
"stft",
reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"),
),
@@ -1961,8 +1976,10 @@
"addr": [3e-3, 4e-3],
"baddbmm": [3e-2, 1e-3],
"cumulative_trapezoid": [3e-2, 1e-3],
+ "cross": [3e-2, 2e-2],
"diff": [1e-2, 5e-2],
"gradient": [3e-3, 4e-3],
+ "linalg.cross": [1e-3, 2e-2],
"linalg.multi_dot": [3e-2, 1e-3],
"linalg.vecdot": [1e-2, 2e-2],
"linspace": [2e-2, 2e-3],
@@ -1977,6 +1994,7 @@
"nn.functional.hardsigmoid": [1e-3, 5e-3],
"nn.functional.hardswish": [1e-3, 5e-3],
"nn.functional.hinge_embedding_loss": [4e-1, 3e-3],
+ "nn.functional.huber_loss": [1e-3, 1e-2],
"nn.functional.instance_norm": [1e-2, 1e-3],
"nn.functional.interpolate": [1e-2, 1e-3],
"nn.functional.kl_div": [2e-3, 2e-4],
@@ -1984,6 +2002,8 @@
"nn.functional.local_response_norm": [1e-2, 5e-3],
"nn.functional.poisson_nll_loss": [3e-2, 1e-3],
"nn.functional.nll_loss": [3e-2, 1e-3],
+ "nn.functional.triplet_margin_loss": [2e-2, 1e-2],
+ "nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2],
"native_batch_norm": [3e-2, 1e-3],
"norm": [1e-2, 1e-2],
"dot": [3e-2, 1e-3],
@@ -1993,6 +2013,7 @@
"sub": [3e-2, 1e-3],
"trapezoid": [1e-3, 7e-3],
"trapz": [1e-3, 7e-3],
+ "vdot": [1e-3, 1e-2],
}
fp16_low_precision_variant_dict = {
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 5345e02..b70bfbf 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -1275,7 +1275,7 @@
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
)
@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
- error_message="aot_autograd expected to have an entirely functional graph",
+ error_message="n=copy_, n.args[0]=zeros_like, placeholders={",
reason="aot_autograd doesn't support it.",
)
def test_fake_tensor_mode_huggingface_openai_whisper(self):
diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py
index 4f3f705..5cb9be6 100644
--- a/torch/onnx/_internal/fx/decomposition_table.py
+++ b/torch/onnx/_internal/fx/decomposition_table.py
@@ -111,4 +111,12 @@
):
continue
decomposition_table[op_overload] = decomp_fn
+
+ # NOTE: There are ops in core ATen and under torch._refs,
+ # that are not decomposed to prim::ops. We need to pick them
+ # back
+ for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items():
+ if op_overload in _ONNX_SUPPORT_OP_OVERLOADS:
+ continue
+ decomposition_table[op_overload] = decomp_fn
return decomposition_table