[MPS][BE] Error-check linear (#124952)
Validate that all tensor arguments are on the MPS device and that their dtypes are the supported floating-point types.
Fixes cryptic error messages like
```
% python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32), torch.rand((32, 32), device='mps')))"
RuntimeError: Placeholder storage has not been allocated on MPS device!
```
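With the checks added below, the same mixed-device call should fail up front with an explicit message instead (output sketched from the new `TORCH_CHECK` wording; the exact formatting of the `RuntimeError` may differ slightly):
```
% python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32), torch.rand((32, 32), device='mps')))"
RuntimeError: Tensor for argument input is on cpu but expected on mps
```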
And hard crashes like
```
% python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32, device='mps'), torch.randint(-10, 10, (32, 32), dtype=torch.int8, device='mps')))"
```
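With the new dtype validation, the int8-weight call above should instead raise an ordinary Python exception (again sketched from the added `TORCH_CHECK` message rather than captured output):
```
% python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32, device='mps'), torch.randint(-10, 10, (32, 32), dtype=torch.int8, device='mps')))"
RuntimeError: MPS device does not support linear for non-float weights
```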
Fixes https://github.com/pytorch/pytorch/issues/123995
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124952
Approved by: https://github.com/Skylion007
diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm
index 6ed9853..450e24c 100644
--- a/aten/src/ATen/native/mps/operations/Linear.mm
+++ b/aten/src/ATen/native/mps/operations/Linear.mm
@@ -16,9 +16,16 @@
   auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg;
 
   TORCH_CHECK(supportedFloatingType(input), "MPS device does not support linear for non-float inputs");
+  TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps");
+  TORCH_CHECK(supportedFloatingType(weight_arg), "MPS device does not support linear for non-float weights");
+  TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps");
 
   const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
-  bool is_bias_defined = bias.defined();
+  const bool is_bias_defined = bias.defined();
+  if (is_bias_defined) {
+    TORCH_CHECK(bias.is_mps(), "Tensor for argument bias is on ", bias.device(), " but expected on mps");
+    TORCH_CHECK(supportedFloatingType(bias), "MPS device does not support linear for non-float bias");
+  }
 
   auto input_size = input.sizes();
   std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
diff --git a/test/test_mps.py b/test/test_mps.py
index bfac420..7f87c1c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1961,6 +1961,25 @@
         helper(())
         helper((2, 4))
 
+    def test_linear_errors(self):
+        # Mixed CPU<->MPS tensors
+        size = (3, 3)
+
+        # Unsupported dtypes
+        with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
+            torch.nn.functional.linear(torch.rand(size, device='mps'),
+                                       torch.randint(-10, 10, size, dtype=torch.int8, device='mps'))
+
+        # Weights on wrong device
+        with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
+            torch.nn.functional.linear(torch.rand(size, device='mps'),
+                                       torch.rand(size, device='cpu'))
+
+        # Input on wrong device
+        with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
+            torch.nn.functional.linear(torch.rand(size, device='cpu'),
+                                       torch.rand(size, device='mps'))
+
     def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
         cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
         mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)
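The new test can be run on an Apple Silicon machine with the usual per-test filter, e.g. (one possible invocation, not part of this change):
```
% python3 -m pytest test/test_mps.py -k test_linear_errors
```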