Add decomposition for unsqueeze_copy (#130942)
* Extracted from #128416
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130942
Approved by: https://github.com/peterbell10
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 1ebf152..01de3aa 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1316,8 +1316,6 @@
aten::unsafe_split.Tensor_out
aten::unsafe_split_with_sizes.out
aten::unsqueeze_
-aten::unsqueeze_copy
-aten::unsqueeze_copy.out
aten::upsample_bicubic2d_backward
aten::upsample_bicubic2d_backward.grad_input
aten::upsample_bilinear2d_backward
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 911c71e..03744b7 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1429,6 +1429,7 @@
xfail("masked.cumprod", ""),
xfail("renorm"), # hit vmap fallback, which is disabled
xfail("t_copy"),
+ xfail("unsqueeze_copy"),
}
),
)
@@ -1566,6 +1567,7 @@
"index_fill"
), # aten::_unique hit the vmap fallback which is currently disabled
xfail("t_copy"),
+ xfail("unsqueeze_copy"),
}
),
)
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index fb10f22..6698dc4 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -4363,6 +4363,7 @@
xfail("as_strided"),
xfail("as_strided_copy"),
xfail("t_copy"),
+ xfail("unsqueeze_copy"),
xfail("istft"),
xfail("nonzero"),
xfail("nn.functional.fractional_max_pool2d"),
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index 4129a29..9007d8a 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -1264,6 +1264,11 @@
reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"),
),
xfail(
+ "unsqueeze_copy",
+ reason="OnnxExporterError: Failed to export model",
+ dtypes=(torch.int8, torch.uint8, torch.int16),
+ ),
+ xfail(
"where",
dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
diff --git a/test/test_mps.py b/test/test_mps.py
index 74bf819..90fe6e8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -352,6 +352,7 @@
'unsafe_chunk',
'unsafe_split',
'unsqueeze',
+ 'unsqueeze_copy',
'view_as',
'view_as_real',
'view',
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index ad1da61..a777823 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -200,6 +200,7 @@
"permute",
"squeeze",
"unsqueeze",
+ "unsqueeze_copy",
"resize",
"resize_as",
"tril",
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 0e451eb..984e9aa 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -460,6 +460,7 @@
aten._unsafe_masked_index_put_accumulate,
aten.unsafe_split.Tensor,
aten.unsafe_split_with_sizes,
+ aten.unsqueeze_copy,
aten._unsafe_view,
aten.upsample_linear1d,
aten.upsample_bilinear2d,
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index e28cb13..3c8e9d9 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -294,6 +294,7 @@
"unfold",
"unfold_copy",
"unsqueeze",
+ "unsqueeze_copy",
"view",
"view_as",
"view_copy",
@@ -6321,6 +6322,7 @@
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
t_copy = _make_copy_from_view(aten.t)
+unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 816aa32..f1e4dae 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -19653,6 +19653,29 @@
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
sample_inputs_func=sample_unsqueeze),
+ OpInfo('unsqueeze_copy',
+ dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+ supports_out=True,
+ supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
+ # See https://github.com/pytorch/pytorch/pull/78358
+ check_batched_forward_grad=False,
+ # vmap does not support inplace views
+ check_inplace_batched_forward_grad=False,
+ assert_jit_shape_analysis=True,
+ assert_autodiffed=True,
+ autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
+ autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
+ sample_inputs_func=sample_unsqueeze,
+ skips=(
+ DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+ DecorateInfo(
+ unittest.expectedFailure,
+ 'TestJit',
+ 'test_variant_consistency_jit',
+ dtypes=(torch.float32,),
+ ),
+ )),
BinaryUfuncInfo('xlogy',
aliases=('special.xlogy',),
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
@@ -23947,6 +23970,11 @@
torch_opinfo_name="unsqueeze",
),
PythonRefInfo(
+ "_refs.unsqueeze_copy",
+ torch_opinfo_name="unsqueeze_copy",
+ supports_out=True,
+ ),
+ PythonRefInfo(
"_refs.view",
torch_opinfo_name="view",
),