[MPS] Handle int inputs to matmul ops by returning an error for unsupported data types (#82183)
This continues the series of TestConsistency fixes for the MPS backend.
* Add error messages for unsupported matmul ops
* Add error handling for int inputs to the linear op
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82183
Approved by: https://github.com/razarmehr
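For illustration, a minimal sketch of the user-facing behavior these checks introduce, assuming an MPS-capable machine (the shapes and the int64 dtype are illustrative; int64 stands in for any non-float input, and the expected message comes from the `TORCH_CHECK` added below):

```python
import torch

if torch.backends.mps.is_available():
    # Integer matmul is not supported by the MPS kernels, so the new
    # check should fail fast with a readable error instead.
    a = torch.ones(2, 3, dtype=torch.int64, device="mps")
    b = torch.ones(3, 4, dtype=torch.int64, device="mps")
    try:
        torch.mm(a, b)
    except RuntimeError as e:
        # Expected: "MPS device does not support mm for non-float inputs"
        print(e)
```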
diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm
index 181ee8d..a6710ea 100644
--- a/aten/src/ATen/native/mps/operations/Linear.mm
+++ b/aten/src/ATen/native/mps/operations/Linear.mm
@@ -25,6 +25,10 @@
using namespace mps;
+ TORCH_CHECK(input.scalar_type() == ScalarType::Double
+ || input.scalar_type() == ScalarType::Float
+ || input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs");
+
// See [Note: hacky wrapper removal for optional tensor]
auto bias = bias_opt.has_value()
? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
@@ -156,6 +160,10 @@
TORCH_CHECK(weight.device().is_mps() && weight.scalar_type() == kFloat,
"mps_linear_backward: weight needs to be a dense tensor");
+ TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double
+ || grad_output.scalar_type() == ScalarType::Float
+ || grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
+
const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous();
struct CachedGraph : public mps::MPSCachedGraph
@@ -232,6 +240,10 @@
TORCH_CHECK(grad_output.is_mps() && input.is_mps(),
"_mps_linear_backward: grad_output and input needs to be mps layout");
+ TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double
+ || grad_output.scalar_type() == ScalarType::Float
+ || grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
+
struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
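A similar sketch for the `Linear.mm` checks above, from the Python side (hypothetical shapes; this assumes `torch.nn.functional.linear` dispatches to the MPS linear path for MPS tensors):

```python
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    # Non-float input/weight should now trip the forward check.
    x = torch.ones(4, 8, dtype=torch.int64, device="mps")
    w = torch.ones(16, 8, dtype=torch.int64, device="mps")
    try:
        F.linear(x, w)
    except RuntimeError as e:
        # Expected: "MPS device does not support linear for non-float inputs"
        print(e)
```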
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 9306de8..8b69c65 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -100,6 +100,9 @@
Tensor& output) {
using namespace mps;
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
+ TORCH_CHECK(self.scalar_type() == ScalarType::Double
+ || self.scalar_type() == ScalarType::Float
+ || self.scalar_type() == ScalarType::Half, "MPS device does not support mm for non-float inputs");
TensorArg args[]{{output, "out", 0}, {self, "mat1", 1}, {other, "mat2", 2}};
checkAllSameGPU("mm", args);
@@ -208,6 +211,9 @@
TORCH_CHECK(output.is_mps());
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
+ TORCH_CHECK(self.scalar_type() == ScalarType::Double
+ || self.scalar_type() == ScalarType::Float
+ || self.scalar_type() == ScalarType::Half, "MPS device does not support addmm for non-float input");
TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
checkAllSameGPU(__func__, args);
@@ -366,6 +372,10 @@
Tensor & result) {
using namespace mps;
+ TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
+ || batch1.scalar_type() == ScalarType::Float
+ || batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs");
+
if (batch1.numel() == 0 || batch2.numel() == 0) {
return result;
}
@@ -444,6 +454,10 @@
TORCH_CHECK(batch2.is_mps());
TORCH_CHECK(result.is_mps());
+ TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
+ || batch1.scalar_type() == ScalarType::Float
+ || batch1.scalar_type() == ScalarType::Half, "MPS device does not support addbmm or baddbmm for non-float inputs");
+
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
TORCH_CHECK(batch1.size(0) == batch2.size(0),
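And for the batched ops guarded above, a hedged sketch of what a caller would see (again assuming an MPS-capable machine; `torch.randint` produces int64 tensors by default):

```python
import torch

if torch.backends.mps.is_available():
    b1 = torch.randint(0, 10, (2, 3, 4), device="mps")
    b2 = torch.randint(0, 10, (2, 4, 5), device="mps")
    try:
        torch.bmm(b1, b2)
    except RuntimeError as e:
        # Expected: "MPS device does not support bmm for non-float inputs"
        print(e)

    # Float inputs remain on the supported path.
    out = torch.bmm(b1.float(), b2.float())
    print(out.shape)  # torch.Size([2, 3, 5])
```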