[MPS] Fix sort with empty tensor. (#109584)
Fixes #107284
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109584
Approved by: https://github.com/kulinseth, https://github.com/albanD
ghstack dependencies: #109557, #109574
diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm
index e29bc4b..bb12aa6 100644
--- a/aten/src/ATen/native/mps/operations/Sort.mm
+++ b/aten/src/ATen/native/mps/operations/Sort.mm
@@ -29,6 +29,10 @@
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out");
+ if (self.numel() == 0) {
+ return;
+ }
+
values.copy_(self);
// check if self is scalar
dim = maybe_wrap_dim(dim, self.dim(), true);
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 53509f7..80d0a06 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3181,6 +3181,11 @@
yield SampleInput(torch.tensor(1, **tensor_opt), 0)
yield SampleInput(torch.tensor(1, **tensor_opt), 0, True)
+ # Test cases for empty tensor
+ yield SampleInput(torch.tensor((), **tensor_opt))
+ yield SampleInput(torch.tensor((), **tensor_opt), 0)
+ yield SampleInput(torch.tensor((), **tensor_opt), 0, True)
+
# Test cases for stable sort
yield SampleInput(small_3d_unique(), stable=True)
yield SampleInput(small_3d_unique(), dim=0, stable=True)