[MPS] Fix conv1d backwards crash for channels last case (#85283)

Fixes pytorch#84511

Use the same logic as in the forward pass for the backward pass (in case of channels last, manually set the shape to NHWC)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85283
Approved by: https://github.com/malfet, https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm
index 41d68d4..88bad9a 100644
--- a/aten/src/ATen/native/mps/operations/Convolution.mm
+++ b/aten/src/ATen/native/mps/operations/Convolution.mm
@@ -39,6 +39,19 @@
   descriptor_.groups = groups;
 }
 
+static
+MPSShape* get_mps_conv_shape(const Tensor& tensor, bool is_channels_last) {
+  if (is_channels_last) {
+    const auto tensorSizes = tensor.sizes();
+    const NSUInteger N = tensorSizes[0];
+    const NSUInteger C = tensorSizes[1];
+    const NSUInteger H = tensorSizes[2];
+    const NSUInteger W = tensorSizes[3];
+    return @[@(N), @(H), @(W), @(C)];
+  }
+  return at::native::mps::getMPSShape(tensor);
+}
+
 Tensor _mps_convolution(
     const Tensor& input_t,
     const Tensor& weight_t,
@@ -126,19 +139,7 @@
                                     + mps::getTensorsStringKey({input_t, weight_t}) + ":"
                                     + to_string(bias_defined) + ":" + bias_shape_key;
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-    MPSShape* inputShape = nil;
-
-    if (is_channels_last) {
-      const auto inputSizes = input_t.sizes();
-      const NSUInteger N = inputSizes[0];
-      const NSUInteger C = inputSizes[1];
-      const NSUInteger H = inputSizes[2];
-      const NSUInteger W = inputSizes[3];
-      inputShape = @[@(N), @(H), @(W), @(C)];
-    } else {
-      inputShape = native_mps::getMPSShape(input_t);
-    }
-
+    MPSShape* inputShape = get_mps_conv_shape(input_t, is_channels_last);
     if(!cachedGraph) {
       native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
 
@@ -333,6 +334,9 @@
   using namespace mps;
   CheckedFrom c = "mps_convolution_backward_weights";
   auto memory_format = input_t.suggest_memory_format();
+  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
+  MPSShape* inputShape = get_mps_conv_shape(input_t, is_channels_last);
+  MPSShape* gradOutputShape = get_mps_conv_shape(grad_output_t, is_channels_last);
 
   // For uniformity with everything else, although it seems grad_weight
   // would be unambiguous too.
@@ -399,8 +403,8 @@
                                       padding[1], padding[0],
                                       memory_format, groups);
 
-          MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
-          MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+          MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
+          MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
 
           MPSGraphTensor* gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensor
                                                                                                  sourceTensor:inputTensor
@@ -417,8 +421,8 @@
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
-    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
+    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
+    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
     auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);
 
     NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
diff --git a/test/test_mps.py b/test/test_mps.py
index e036f69..ccb3a29 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6005,6 +6005,19 @@
 
         self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
 
+    def test_conv_backward_1d_channels_last(self):
+        # https://github.com/pytorch/pytorch/issues/84511
+        conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
+        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
+
+        data = torch.rand(1, 176, 1, dtype=torch.float32)
+        x_cpu = data.permute(0, 2, 1).contiguous()
+        x_mps = data.permute(0, 2, 1).contiguous().to("mps")
+        res_cpu = conv_cpu(x_cpu).sum().backward()
+        res_mps = conv_mps(x_mps).sum().backward()
+
+        self.assertEqual(res_cpu, res_mps)
+
     def test_conv1d_contiguous(self):
         model_cpu = torch.nn.Conv1d(1, 128, 3)
         a_cpu = torch.ones(128, 1, 176)