[MPS] Fix conv1d backwards crash for channels last case (#85283)
Fixes pytorch#84511
Use the same logic in the backward pass as in the forward pass: for channels-last inputs, manually set the MPS shape to NHWC instead of taking the tensor's logical sizes.
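A minimal repro of the crash, adapted from the regression test added below (shapes are illustrative; the exact shapes in the issue may differ):

```python
import torch

conv = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3).to("mps")
data = torch.rand(1, 176, 1, dtype=torch.float32)
x = data.permute(0, 2, 1).contiguous().to("mps")  # (N=1, C=1, L=176)
conv(x).sum().backward()  # crashed in the MPS backward pass before this fix
```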
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85283
Approved by: https://github.com/malfet, https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm
index 41d68d4..88bad9a 100644
--- a/aten/src/ATen/native/mps/operations/Convolution.mm
+++ b/aten/src/ATen/native/mps/operations/Convolution.mm
@@ -39,6 +39,19 @@
descriptor_.groups = groups;
}
+static
+MPSShape* get_mps_conv_shape(const Tensor& tensor, bool is_channels_last) {
+ if (is_channels_last) {
+ const auto tensorSizes = tensor.sizes();
+ const NSUInteger N = tensorSizes[0];
+ const NSUInteger C = tensorSizes[1];
+ const NSUInteger H = tensorSizes[2];
+ const NSUInteger W = tensorSizes[3];
+ return @[@(N), @(H), @(W), @(C)];
+ }
+ return at::native::mps::getMPSShape(tensor);
+}
+
Tensor _mps_convolution(
const Tensor& input_t,
const Tensor& weight_t,
@@ -126,19 +139,7 @@
+ mps::getTensorsStringKey({input_t, weight_t}) + ":"
+ to_string(bias_defined) + ":" + bias_shape_key;
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
- MPSShape* inputShape = nil;
-
- if (is_channels_last) {
- const auto inputSizes = input_t.sizes();
- const NSUInteger N = inputSizes[0];
- const NSUInteger C = inputSizes[1];
- const NSUInteger H = inputSizes[2];
- const NSUInteger W = inputSizes[3];
- inputShape = @[@(N), @(H), @(W), @(C)];
- } else {
- inputShape = native_mps::getMPSShape(input_t);
- }
-
+ MPSShape* inputShape = get_mps_conv_shape(input_t, is_channels_last);
if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
@@ -333,6 +334,9 @@
using namespace mps;
CheckedFrom c = "mps_convolution_backward_weights";
auto memory_format = input_t.suggest_memory_format();
+ bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
+ MPSShape* inputShape = get_mps_conv_shape(input_t, is_channels_last);
+ MPSShape* gradOutputShape = get_mps_conv_shape(grad_output_t, is_channels_last);
// For uniformity with everything else, although it seems grad_weight
// would be unambiguous too.
@@ -399,8 +403,8 @@
padding[1], padding[0],
memory_format, groups);
- MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
- MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+ MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
+ MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
MPSGraphTensor* gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensor
sourceTensor:inputTensor
@@ -417,8 +421,8 @@
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
- auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
- auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
+ auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
+ auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
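For context on the fix: a channels-last tensor in PyTorch keeps its logical NCHW sizes and encodes the layout in its strides, so `getMPSShape` (which reads sizes) would hand MPSGraph an NCHW shape for NHWC-ordered data. That is why `get_mps_conv_shape` permutes the sizes to NHWC by hand. A quick illustration (a sketch, not part of the patch):

```python
import torch

# Channels-last only reorders memory; the reported sizes stay NCHW.
x = torch.randn(1, 3, 4, 5).contiguous(memory_format=torch.channels_last)
print(x.shape)    # torch.Size([1, 3, 4, 5])  -- logical NCHW sizes
print(x.stride()) # (60, 1, 15, 3)            -- NHWC memory layout
```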
diff --git a/test/test_mps.py b/test/test_mps.py
index e036f69..ccb3a29 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6005,6 +6005,19 @@
self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
+ def test_conv_backward_1d_channels_last(self):
+ # https://github.com/pytorch/pytorch/issues/84511
+ conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
+ conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
+
+ data = torch.rand(1, 176, 1, dtype=torch.float32)
+ x_cpu = data.permute(0, 2, 1).contiguous()
+ x_mps = data.permute(0, 2, 1).contiguous().to("mps")
+ conv_cpu(x_cpu).sum().backward()
+ conv_mps(x_mps).sum().backward()
+
+ self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad.cpu())
+
def test_conv1d_contiguous(self):
model_cpu = torch.nn.Conv1d(1, 128, 3)
a_cpu = torch.ones(128, 1, 176)
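Usage note: assuming the standard unittest entry point in test_mps.py, the new case can be run in isolation with something like `python test/test_mps.py -k test_conv_backward_1d_channels_last`. Comparing the weight gradients (rather than the forward output alone) exercises the mps_convolution_backward_weights path that previously crashed.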