Add out wrappers to some decompositions (#115437)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115437
Approved by: https://github.com/lezcano
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index b330aa7..dc3d8cc 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -22,6 +22,9 @@
 aten::_softmax.out
 aten::_to_copy
 aten::_to_copy.out
+aten::_upsample_nearest_exact1d.out
+aten::_upsample_nearest_exact2d.out
+aten::_upsample_nearest_exact3d.out
 aten::abs
 aten::abs.out
 aten::abs_
@@ -508,6 +511,10 @@
 aten::uniform_
 aten::unsqueeze
 aten::upsample_bicubic2d
+aten::upsample_bicubic2d.out
+aten::upsample_nearest1d.out
+aten::upsample_nearest2d.out
+aten::upsample_nearest3d.out
 aten::var.correction
 aten::var.correction_out
 aten::var_mean.correction
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 8fbdc43..2fc26d1 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -609,13 +609,10 @@
 aten::_upsample_bilinear2d_aa.out
 aten::_upsample_bilinear2d_aa_backward
 aten::_upsample_bilinear2d_aa_backward.grad_input
-aten::_upsample_nearest_exact1d.out
 aten::_upsample_nearest_exact1d_backward
 aten::_upsample_nearest_exact1d_backward.grad_input
-aten::_upsample_nearest_exact2d.out
 aten::_upsample_nearest_exact2d_backward
 aten::_upsample_nearest_exact2d_backward.grad_input
-aten::_upsample_nearest_exact3d.out
 aten::_upsample_nearest_exact3d_backward
 aten::_upsample_nearest_exact3d_backward.grad_input
 aten::_use_cudnn_ctc_loss
@@ -1331,20 +1328,16 @@
 aten::unsqueeze_
 aten::unsqueeze_copy
 aten::unsqueeze_copy.out
-aten::upsample_bicubic2d.out
 aten::upsample_bicubic2d_backward
 aten::upsample_bicubic2d_backward.grad_input
 aten::upsample_bilinear2d_backward
 aten::upsample_bilinear2d_backward.grad_input
 aten::upsample_linear1d_backward
 aten::upsample_linear1d_backward.grad_input
-aten::upsample_nearest1d.out
 aten::upsample_nearest1d_backward
 aten::upsample_nearest1d_backward.grad_input
-aten::upsample_nearest2d.out
 aten::upsample_nearest2d_backward
 aten::upsample_nearest2d_backward.grad_input
-aten::upsample_nearest3d.out
 aten::upsample_nearest3d_backward
 aten::upsample_nearest3d_backward.grad_input
 aten::upsample_trilinear3d_backward
diff --git a/test/test_torch.py b/test/test_torch.py
index 735a4f4..25d1cc1 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8847,7 +8847,7 @@
         out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
         self.assertExpectedRaisesInline(
             RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
-            """Expected out tensor to have dtype float, but got double instead"""
+            """Expected out tensor to have dtype torch.float32 but got torch.float64 instead"""
         )
 
         # Complain if out device mismatch
@@ -8857,7 +8857,7 @@
         if not TEST_WITH_TORCHINDUCTOR:
             self.assertExpectedRaisesInline(
                 RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
-                """Expected out tensor to have device meta, but got cpu instead"""
+                """Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!"""
             )
 
     def test_add_meta_scalar(self):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 3ef43ad..3b69cc5 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -2720,9 +2720,10 @@
     return indices
 
 
-@register_decomposition(aten.upsample_nearest1d.default)
+@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
 @aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest1d(
     input: Tensor,
     output_size: List[int],
@@ -2731,9 +2732,12 @@
     return _upsample_nearest(input, output_size, [scales])
 
 
-@register_decomposition(aten._upsample_nearest_exact1d.default)
+@register_decomposition(
+    [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
+)
 @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest_exact1d(
     input: Tensor,
     output_size: List[int],
@@ -2742,9 +2746,10 @@
     return _upsample_nearest(input, output_size, [scales], exact=True)
 
 
-@register_decomposition(aten.upsample_nearest2d.default)
+@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
 @aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest2d(
     input: Tensor,
     output_size: List[int],
@@ -2754,9 +2759,12 @@
     return _upsample_nearest(input, output_size, [scales_h, scales_w])
 
 
-@register_decomposition(aten._upsample_nearest_exact2d.default)
+@register_decomposition(
+    [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
+)
 @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def _upsample_nearest_exact2d(
     input: Tensor,
     output_size: List[int],
@@ -2766,9 +2774,10 @@
     return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
 
 
-@register_decomposition(aten.upsample_nearest3d.default)
+@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
 @aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def upsample_nearest3d(
     input: Tensor,
     output_size: List[int],
@@ -2779,9 +2788,12 @@
     return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
 
 
-@register_decomposition(aten._upsample_nearest_exact3d.default)
+@register_decomposition(
+    [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
+)
 @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
 def _upsample_nearest_exact3d(
     input: Tensor,
     output_size: List[int],
@@ -4251,8 +4263,9 @@
         torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
 
 
-@register_decomposition(aten.upsample_bicubic2d.default)
+@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
 @aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper()
 @pw_cast_for_opmath
 def upsample_bicubic2d_default(
     input: Tensor,
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index 8b7515b..9057edc 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -170,9 +170,13 @@
 
 
 # TODO: handle tuples of tensors
-def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+def _maybe_resize_out(
+    out: TensorLikeType,
+    shape: ShapeType,
+    memory_format: Optional[torch.memory_format] = None,
+):
     if _resize_output_check(out, shape):
-        return out.resize_(shape)
+        return out.resize_(shape, memory_format=memory_format)
     else:
         return out
 
@@ -205,7 +209,12 @@
     return copy_to.copy_(copy_from)
 
 
-def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False):
+def out_wrapper(
+    *out_names: str,
+    exact_dtype: bool = False,
+    pass_is_out: bool = False,
+    preserve_memory_format=False,
+):
     # The wrapped function needs to convert the output parameters to ensure
     # compatibility between the Python API (which always uses "out" as the
     # parameter name and may be a tuple) and the Aten API (which may have
@@ -219,6 +228,9 @@
 
     is_tensor = len(out_names) == 1
 
+    def maybe_compute_memory_format(t):
+        return utils.suggest_memory_format(t) if preserve_memory_format else None
+
     def _out_wrapper(fn: Callable) -> Callable:
         """
         Adds the out parameter to a Python reference.
@@ -277,7 +289,9 @@
                 if is_tensor:
                     assert isinstance(out, TensorLike)
                     # These two operations are done in-place
-                    _maybe_resize_out(out, result.shape)
+                    _maybe_resize_out(
+                        out, result.shape, maybe_compute_memory_format(result)
+                    )
                     _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                 else:
                     assert isinstance(out, Tuple)  # type: ignore[arg-type]
@@ -287,7 +301,7 @@
                     )
                     for r, o in zip(result, out):
                         # These two operations are done in-place
-                        _maybe_resize_out(o, r.shape)
+                        _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
                         _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
             else:
                 out = result