[MPS] Fix the cat op for NHWC case (#94662)

* Add a unit test for cat with non-contiguous inputs

Fixes #85675

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94662
Approved by: https://github.com/DenisVieriu97
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index 000dbd3..4127dda 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -224,6 +224,7 @@
        const Tensor& out) {
 
   using namespace mps;
+
   if (out.numel() == 0) {
     return;
   }
@@ -288,6 +289,10 @@
               "torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
               notSkippedTensor.device(), " and out is on ", out.device());
 
+  // TODO: For better performance by eliminating input tensor gathering and post transpose,
+  // it is better to keep the out tensor's memory format.
+  // The dimension would need to be recomputed as:
+  // dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim() - dim; otherwise dim = dim - 1
   if (out.suggest_memory_format() == MemoryFormat::ChannelsLast) {
     out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
   }
@@ -308,7 +313,7 @@
   size[dimension] = cat_dim_size;
   // skip resizing if size of result is same as expected
   if (out.sizes() != size) {
-    out.resize_(size, memory_format);
+    out.resize_(size, MemoryFormat::Contiguous);
   }
   if (out.numel() == 0) {
     return;
@@ -344,7 +349,7 @@
             if (tensor.scalar_type() == kBool) {
               scalar_type = MPSDataTypeInt8;
             }
-            newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, memory_format));
+            newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
             if (tensor.scalar_type() != out_dtype) {
               castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
                                                     toType:getMPSDataType(out_dtype)
@@ -364,8 +369,7 @@
                                          toType:MPSDataTypeBool
                                            name:@"outputTensor"];
           }
-          newCachedGraph->outputTensor_ = memory_format == MemoryFormat::ChannelsLast ?
-                                         convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
+          newCachedGraph->outputTensor_ = outputTensor;
         }
         return newCachedGraph;
       });
@@ -381,8 +385,8 @@
           scalar_type = MPSDataTypeInt8;
         }
         inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
-                                       getMPSShape(tensor, memory_format),
-                                       memory_format != MemoryFormat::ChannelsLast, scalar_type);
+                                       getMPSShape(tensor, MemoryFormat::Contiguous),
+                                       /*gatherTensorData*/true, scalar_type);
         t_idx++;
       }
       i++;
diff --git a/test/test_mps.py b/test/test_mps.py
index fc7b475..bd3f5c1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2186,16 +2186,25 @@
 
     # See https://github.com/pytorch/pytorch/issues/85675
     def test_cat_non_contiguous(self):
-        def rotate_subset(data):
-            return torch.concat([data[:, :2], torch.rot90(data[:, 2:])])
+        def rotate_subset(data, dim):
+            x1 = data[:, :, :2, :]
+            x2 = data[:, :, 2:, :]
+            self.assertFalse(x1.is_contiguous())
+            self.assertFalse(x2.is_contiguous())
+            return torch.concat((x1, x2), dim=dim)
         for dtype in MPS_DTYPES:
             if dtype == torch.bool:
                 continue
-            data = torch.arange(8, dtype=dtype).reshape(2, 4)
+            data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
+            data = data.to(memory_format=torch.channels_last)
             mps_data = data.to("mps")
-            cpu_result = rotate_subset(data)
-            mps_result = rotate_subset(mps_data)
-            self.assertEqual(cpu_result, mps_result.to("cpu"))
+            self.assertEqual(data, mps_data)
+            for dim in range(data.dim()):
+                cpu_result = rotate_subset(data, dim)
+                mps_result = rotate_subset(mps_data, dim)
+                self.assertEqual(cpu_result, mps_result.to("cpu"))
+                # TODO: enable memory format test
+                # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
 
     # See https://github.com/pytorch/pytorch/issues/85967
     def test_from_numpy_non_contiguous(self):