[MPS] Fix the cat op for NHWC case (#94662)
* Add a unit test for cat with non-contiguous (channels-last) inputs
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94662
Approved by: https://github.com/DenisVieriu97
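
A minimal repro of the failure mode this patch addresses, distilled from the unit test added below (the dtype is arbitrary, and an MPS-capable build is assumed):

```python
import torch

if torch.backends.mps.is_available():
    # Channels-last storage makes slices along H non-contiguous.
    data = torch.arange(48, dtype=torch.float32).reshape(1, 2, 4, 6)
    data = data.to(memory_format=torch.channels_last)
    mps_data = data.to("mps")

    x1, x2 = mps_data[:, :, :2, :], mps_data[:, :, 2:, :]
    assert not x1.is_contiguous() and not x2.is_contiguous()

    # Concatenating the two slices along H should reassemble the original
    # tensor; before this fix the MPS result did not match the CPU.
    out = torch.cat((x1, x2), dim=2)
    torch.testing.assert_close(out.cpu(), data)
```
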
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index 000dbd3..4127dda 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -224,6 +224,7 @@
const Tensor& out) {
using namespace mps;
+
if (out.numel() == 0) {
return;
}
@@ -288,6 +289,10 @@
"torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
notSkippedTensor.device(), " and out is on ", out.device());
+ // TODO: For better performance, keep the out tensor's memory format; that would
+ // TODO: eliminate the input tensor gathering and the post-cat transpose.
+ // TODO: The cat dimension would then need to be recomputed (see the sketch after
+ // TODO: this diff) as: dim = 0 --> dim = 0; dim = 1 --> dim = out.dim() - 1; otherwise dim = dim - 1
if (out.suggest_memory_format() == MemoryFormat::ChannelsLast) {
out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
@@ -308,7 +313,7 @@
size[dimension] = cat_dim_size;
// skip resizing if size of result is same as expected
if (out.sizes() != size) {
- out.resize_(size, memory_format);
+ out.resize_(size, MemoryFormat::Contiguous);
}
if (out.numel() == 0) {
return;
@@ -344,7 +349,7 @@
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
- newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, memory_format));
+ newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
if (tensor.scalar_type() != out_dtype) {
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
toType:getMPSDataType(out_dtype)
@@ -364,8 +369,7 @@
toType:MPSDataTypeBool
name:@"outputTensor"];
}
- newCachedGraph->outputTensor_ = memory_format == MemoryFormat::ChannelsLast ?
- convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
+ newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
@@ -381,8 +385,8 @@
scalar_type = MPSDataTypeInt8;
}
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
- getMPSShape(tensor, memory_format),
- memory_format != MemoryFormat::ChannelsLast, scalar_type);
+ getMPSShape(tensor, MemoryFormat::Contiguous),
+ /*gatherTensorData*/true, scalar_type);
t_idx++;
}
i++;
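
The remapping mentioned in the TODO above is the usual NCHW-to-NHWC axis translation: N stays first, C moves last, and the remaining dims shift down by one. A rough Python sketch of the idea, under the assumption that the future fast path would run the graph in NHWC order (the helper name `cat_keeping_channels_last` is hypothetical, and the 4-D permutation is hard-coded for clarity):

```python
import torch

def cat_keeping_channels_last(tensors, dim):
    # Remap an NCHW cat dim to its NHWC position:
    # N (0) stays 0, C (1) moves last, H (2) and W (3) shift down by one.
    ndim = tensors[0].dim()
    nhwc_dim = 0 if dim == 0 else (ndim - 1 if dim == 1 else dim - 1)
    # permute creates NHWC views of the NCHW inputs; cat along the remapped
    # dim, then permute the result back to NCHW order.
    nhwc = [t.permute(0, 2, 3, 1) for t in tensors]
    return torch.cat(nhwc, dim=nhwc_dim).permute(0, 3, 1, 2)

a, b = torch.randn(1, 2, 4, 6), torch.randn(1, 3, 4, 6)
assert torch.equal(cat_keeping_channels_last([a, b], dim=1),
                   torch.cat([a, b], dim=1))
```
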
diff --git a/test/test_mps.py b/test/test_mps.py
index fc7b475..bd3f5c1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2186,16 +2186,25 @@
# See https://github.com/pytorch/pytorch/issues/85675
def test_cat_non_contiguous(self):
- def rotate_subset(data):
- return torch.concat([data[:, :2], torch.rot90(data[:, 2:])])
+ def rotate_subset(data, dim):
+ x1 = data[:, :, :2, :]
+ x2 = data[:, :, 2:, :]
+ self.assertFalse(x1.is_contiguous())
+ self.assertFalse(x2.is_contiguous())
+ return torch.concat((x1, x2), dim=dim)
for dtype in MPS_DTYPES:
if dtype == torch.bool:
continue
- data = torch.arange(8, dtype=dtype).reshape(2, 4)
+ data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
+ data = data.to(memory_format=torch.channels_last)
mps_data = data.to("mps")
- cpu_result = rotate_subset(data)
- mps_result = rotate_subset(mps_data)
- self.assertEqual(cpu_result, mps_result.to("cpu"))
+ self.assertEqual(data, mps_data)
+ for dim in range(data.dim()):
+ cpu_result = rotate_subset(data, dim)
+ mps_result = rotate_subset(mps_data, dim)
+ self.assertEqual(cpu_result, mps_result.to("cpu"))
+ # TODO: enable memory format test
+ # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
# See https://github.com/pytorch/pytorch/issues/85967
def test_from_numpy_non_contiguous(self):
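
For context on the new test: channels-last storage keeps C innermost, so slicing along H yields strides that no longer match the row-major layout, which is what makes `x1` and `x2` non-contiguous. A quick CPU-only illustration with the test's shapes:

```python
import torch

data = torch.arange(48).reshape(1, 2, 4, 6)          # N, C, H, W
data = data.to(memory_format=torch.channels_last)

print(data.stride())       # (48, 1, 12, 2): C is the innermost dimension
x1 = data[:, :, :2, :]     # slicing along H keeps the parent's strides
print(x1.is_contiguous())  # False: (48, 1, 12, 2) is not row-major for (1, 2, 2, 6)
```
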