Add decomposition for unsqueeze_copy (#130942)

* Extracted from #128416
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130942
Approved by: https://github.com/peterbell10
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 1ebf152..01de3aa 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1316,8 +1316,6 @@
 aten::unsafe_split.Tensor_out
 aten::unsafe_split_with_sizes.out
 aten::unsqueeze_
-aten::unsqueeze_copy
-aten::unsqueeze_copy.out
 aten::upsample_bicubic2d_backward
 aten::upsample_bicubic2d_backward.grad_input
 aten::upsample_bilinear2d_backward
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 911c71e..03744b7 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1429,6 +1429,7 @@
                 xfail("masked.cumprod", ""),
                 xfail("renorm"),  # hit vmap fallback, which is disabled
                 xfail("t_copy"),
+                xfail("unsqueeze_copy"),
             }
         ),
     )
@@ -1566,6 +1567,7 @@
                     "index_fill"
                 ),  # aten::_unique hit the vmap fallback which is currently disabled
                 xfail("t_copy"),
+                xfail("unsqueeze_copy"),
             }
         ),
     )
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index fb10f22..6698dc4 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -4363,6 +4363,7 @@
                 xfail("as_strided"),
                 xfail("as_strided_copy"),
                 xfail("t_copy"),
+                xfail("unsqueeze_copy"),
                 xfail("istft"),
                 xfail("nonzero"),
                 xfail("nn.functional.fractional_max_pool2d"),
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index 4129a29..9007d8a 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -1264,6 +1264,11 @@
         reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"),
     ),
     xfail(
+        "unsqueeze_copy",
+        reason="OnnxExporterError: Failed to export model",
+        dtypes=(torch.int8, torch.uint8, torch.int16),
+    ),
+    xfail(
         "where",
         dtypes=onnx_test_common.BOOL_TYPES,
         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
diff --git a/test/test_mps.py b/test/test_mps.py
index 74bf819..90fe6e8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -352,6 +352,7 @@
         'unsafe_chunk',
         'unsafe_split',
         'unsqueeze',
+        'unsqueeze_copy',
         'view_as',
         'view_as_real',
         'view',
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index ad1da61..a777823 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -200,6 +200,7 @@
     "permute",
     "squeeze",
     "unsqueeze",
+    "unsqueeze_copy",
     "resize",
     "resize_as",
     "tril",
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 0e451eb..984e9aa 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -460,6 +460,7 @@
             aten._unsafe_masked_index_put_accumulate,
             aten.unsafe_split.Tensor,
             aten.unsafe_split_with_sizes,
+            aten.unsqueeze_copy,
             aten._unsafe_view,
             aten.upsample_linear1d,
             aten.upsample_bilinear2d,
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index e28cb13..3c8e9d9 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -294,6 +294,7 @@
     "unfold",
     "unfold_copy",
     "unsqueeze",
+    "unsqueeze_copy",
     "view",
     "view_as",
     "view_copy",
@@ -6321,6 +6322,7 @@
 # no sparse support. See narrow_copy_sparse in core.
 narrow_copy = _make_copy_from_view(aten.narrow)
 t_copy = _make_copy_from_view(aten.t)
+unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
 view_copy = _make_copy_from_view(aten.view)
 
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 816aa32..f1e4dae 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -19653,6 +19653,29 @@
            autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
            autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
            sample_inputs_func=sample_unsqueeze),
+    OpInfo('unsqueeze_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           assert_jit_shape_analysis=True,
+           assert_autodiffed=True,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           sample_inputs_func=sample_unsqueeze,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+               DecorateInfo(
+                   unittest.expectedFailure,
+                   'TestJit',
+                   'test_variant_consistency_jit',
+                   dtypes=(torch.float32,),
+               ),
+           )),
     BinaryUfuncInfo('xlogy',
                     aliases=('special.xlogy',),
                     dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
@@ -23947,6 +23970,11 @@
         torch_opinfo_name="unsqueeze",
     ),
     PythonRefInfo(
+        "_refs.unsqueeze_copy",
+        torch_opinfo_name="unsqueeze_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
         "_refs.view",
         torch_opinfo_name="view",
     ),