Add out wrappers to some decompositions (#115437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115437
Approved by: https://github.com/lezcano
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index b330aa7..dc3d8cc 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -22,6 +22,9 @@
aten::_softmax.out
aten::_to_copy
aten::_to_copy.out
+aten::_upsample_nearest_exact1d.out
+aten::_upsample_nearest_exact2d.out
+aten::_upsample_nearest_exact3d.out
aten::abs
aten::abs.out
aten::abs_
@@ -508,6 +511,10 @@
aten::uniform_
aten::unsqueeze
aten::upsample_bicubic2d
+aten::upsample_bicubic2d.out
+aten::upsample_nearest1d.out
+aten::upsample_nearest2d.out
+aten::upsample_nearest3d.out
aten::var.correction
aten::var.correction_out
aten::var_mean.correction
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 8fbdc43..2fc26d1 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -609,13 +609,10 @@
aten::_upsample_bilinear2d_aa.out
aten::_upsample_bilinear2d_aa_backward
aten::_upsample_bilinear2d_aa_backward.grad_input
-aten::_upsample_nearest_exact1d.out
aten::_upsample_nearest_exact1d_backward
aten::_upsample_nearest_exact1d_backward.grad_input
-aten::_upsample_nearest_exact2d.out
aten::_upsample_nearest_exact2d_backward
aten::_upsample_nearest_exact2d_backward.grad_input
-aten::_upsample_nearest_exact3d.out
aten::_upsample_nearest_exact3d_backward
aten::_upsample_nearest_exact3d_backward.grad_input
aten::_use_cudnn_ctc_loss
@@ -1331,20 +1328,16 @@
aten::unsqueeze_
aten::unsqueeze_copy
aten::unsqueeze_copy.out
-aten::upsample_bicubic2d.out
aten::upsample_bicubic2d_backward
aten::upsample_bicubic2d_backward.grad_input
aten::upsample_bilinear2d_backward
aten::upsample_bilinear2d_backward.grad_input
aten::upsample_linear1d_backward
aten::upsample_linear1d_backward.grad_input
-aten::upsample_nearest1d.out
aten::upsample_nearest1d_backward
aten::upsample_nearest1d_backward.grad_input
-aten::upsample_nearest2d.out
aten::upsample_nearest2d_backward
aten::upsample_nearest2d_backward.grad_input
-aten::upsample_nearest3d.out
aten::upsample_nearest3d_backward
aten::upsample_nearest3d_backward.grad_input
aten::upsample_trilinear3d_backward
diff --git a/test/test_torch.py b/test/test_torch.py
index 735a4f4..25d1cc1 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8847,7 +8847,7 @@
out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
self.assertExpectedRaisesInline(
RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
- """Expected out tensor to have dtype float, but got double instead"""
+ """Expected out tensor to have dtype torch.float32 but got torch.float64 instead"""
)
# Complain if out device mismatch
@@ -8857,7 +8857,7 @@
if not TEST_WITH_TORCHINDUCTOR:
self.assertExpectedRaisesInline(
RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
- """Expected out tensor to have device meta, but got cpu instead"""
+ """Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!"""
)
def test_add_meta_scalar(self):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 3ef43ad..3b69cc5 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -2720,9 +2720,10 @@
return indices
-@register_decomposition(aten.upsample_nearest1d.default)
+@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest1d(
input: Tensor,
output_size: List[int],
@@ -2731,9 +2732,12 @@
return _upsample_nearest(input, output_size, [scales])
-@register_decomposition(aten._upsample_nearest_exact1d.default)
+@register_decomposition(
+ [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
+)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest_exact1d(
input: Tensor,
output_size: List[int],
@@ -2742,9 +2746,10 @@
return _upsample_nearest(input, output_size, [scales], exact=True)
-@register_decomposition(aten.upsample_nearest2d.default)
+@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
input: Tensor,
output_size: List[int],
@@ -2754,9 +2759,12 @@
return _upsample_nearest(input, output_size, [scales_h, scales_w])
-@register_decomposition(aten._upsample_nearest_exact2d.default)
+@register_decomposition(
+ [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
+)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def _upsample_nearest_exact2d(
input: Tensor,
output_size: List[int],
@@ -2766,9 +2774,10 @@
return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
-@register_decomposition(aten.upsample_nearest3d.default)
+@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest3d(
input: Tensor,
output_size: List[int],
@@ -2779,9 +2788,12 @@
return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
-@register_decomposition(aten._upsample_nearest_exact3d.default)
+@register_decomposition(
+ [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
+)
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def _upsample_nearest_exact3d(
input: Tensor,
output_size: List[int],
@@ -4251,8 +4263,9 @@
torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
-@register_decomposition(aten.upsample_bicubic2d.default)
+@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
+@out_wrapper()
@pw_cast_for_opmath
def upsample_bicubic2d_default(
input: Tensor,
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index 8b7515b..9057edc 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -170,9 +170,13 @@
# TODO: handle tuples of tensors
-def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+def _maybe_resize_out(
+ out: TensorLikeType,
+ shape: ShapeType,
+ memory_format: Optional[torch.memory_format] = None,
+):
if _resize_output_check(out, shape):
- return out.resize_(shape)
+ return out.resize_(shape, memory_format=memory_format)
else:
return out
@@ -205,7 +209,12 @@
return copy_to.copy_(copy_from)
-def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False):
+def out_wrapper(
+ *out_names: str,
+ exact_dtype: bool = False,
+ pass_is_out: bool = False,
+ preserve_memory_format=False,
+):
# The wrapped function needs to convert the output parameters to ensure
# compatibility between the Python API (which always uses "out" as the
# parameter name and may be a tuple) and the Aten API (which may have
@@ -219,6 +228,9 @@
is_tensor = len(out_names) == 1
+ def maybe_compute_memory_format(t):
+ return utils.suggest_memory_format(t) if preserve_memory_format else None
+
def _out_wrapper(fn: Callable) -> Callable:
"""
Adds the out parameter to a Python reference.
@@ -277,7 +289,9 @@
if is_tensor:
assert isinstance(out, TensorLike)
# These two operations are done in-place
- _maybe_resize_out(out, result.shape)
+ _maybe_resize_out(
+ out, result.shape, maybe_compute_memory_format(result)
+ )
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
assert isinstance(out, Tuple) # type: ignore[arg-type]
@@ -287,7 +301,7 @@
)
for r, o in zip(result, out):
# These two operations are done in-place
- _maybe_resize_out(o, r.shape)
+ _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
_safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
out = result