[MPS] Exclude int64 dtype from reduction ops (#91272)
Reduction ops on MPS don't support the int64 data type. This PR adds an assertion when int64 is used with the min/max reduction ops.
All other integer dtypes are cast to int32 before the reduction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91272
Approved by: https://github.com/razarmehr, https://github.com/malfet
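
For illustration, a minimal sketch of the expected user-visible behavior after this change (the sample values, device guard, and printed messages are assumptions for the example, not taken from the PR):

```python
import torch

assert torch.backends.mps.is_available(), "requires an Apple GPU with MPS"

# int64 min/max should now raise instead of silently running an
# unsupported reduction; the error comes from the new TORCH_CHECK.
x64 = torch.arange(6, device="mps", dtype=torch.int64)
try:
    x64.max()
except RuntimeError as e:
    print("int64 rejected:", e)

# Smaller integer dtypes are cast to int32 internally and reduce correctly;
# the result keeps the input dtype, since the output tensor is allocated
# with input_t.scalar_type().
x16 = torch.tensor([3, -7, 42], device="mps", dtype=torch.int16)
print(x16.max())  # tensor(42, device='mps:0', dtype=torch.int16)
```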
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 684837f..43b3336 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -34,6 +34,15 @@
using namespace mps;
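+// Builds the list of all axes [0, ndim) so the reduction collapses every
+// dimension of the input tensor, replacing the old flatten-to-1D approach.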
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
+ int64_t ndim = t.dim();
+ auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+ for (const auto i: c10::irange(ndim)) {
+ axes[i] = [NSNumber numberWithInteger:i];
+ }
+ return axes;
+}
+
void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
NSMutableArray<NSNumber*> * &apparent_in_shape,
int64_t num_reduce_dims,
@@ -1091,15 +1100,12 @@
Tensor min_max_mps(const Tensor& input_t,
MPSReductionType reduction_type,
const std::string& func_name) {
+ TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "min/max not supported for Long dtype on MPS");
using CachedGraph = MPSUnaryCachedGraph;
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
IntArrayRef input_shape = input_t.sizes();
-
- // Flatten the input tensor to reduce it to one value
- NSMutableArray<NSNumber*> *apparent_input_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
int64_t num_in_elements = c10::multiply_integers(input_shape);
- apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements];
Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
@@ -1121,14 +1127,26 @@
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
MPSGraphTensor* outputTensor = nil;
+ MPSGraphTensor* castInputTensor = nil;
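+ // Only float32, int32, and float16 are reduced natively; all other
+ // dtypes (bool, int8, uint8, int16) are widened to int32 first.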
+ if (input_t.scalar_type() != ScalarType::Float &&
+ input_t.scalar_type() != ScalarType::Int &&
+ input_t.scalar_type() != ScalarType::Half) {
+ castInputTensor = [mpsGraph castTensor:inputTensor
+ toType:MPSDataTypeInt32
+ name:@"castInputTensor"];
+ } else {
+ castInputTensor = inputTensor;
+ }
+
+ NSArray<NSNumber*>* axes = getTensorAxes(input_t);
if (reduction_type == MPSReductionType::MAX) {
- outputTensor = [mpsGraph reductionMaximumWithTensor: inputTensor
- axes: @[@0]
+ outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
+ axes:axes
name: nil];
} else if (reduction_type == MPSReductionType::MIN) {
- outputTensor = [mpsGraph reductionMinimumWithTensor: inputTensor
- axes: @[@0]
+ outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
+ axes:axes
name: nil];
}
@@ -1139,7 +1157,7 @@
});
}
- auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
+ auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, @[@1]);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
@@ -1175,6 +1193,7 @@
const Tensor& indices_t,
MPSReductionType reduction_type,
const std::string& func_name) {
+ TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "min/max not supported for Long dtype on MPS");
if (output_t.numel() == 0) {
return;
@@ -1240,7 +1259,7 @@
input_t.scalar_type() != ScalarType::Int &&
input_t.scalar_type() != ScalarType::Half) {
castInputTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeFloat32
+ toType:MPSDataTypeInt32
name:@"castInputTensor"];
} else {
castInputTensor = inputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index bc93e0d..3550a5a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1880,6 +1880,24 @@
helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
+ def test_min_max(self):
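+ # int64 is deliberately omitted: Long min/max now raises on MPS
+ # (see the TORCH_CHECK in ReduceOps.mm)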
+ def helper(dtype):
+ for _ in range(10):
+ if dtype == torch.float32 or dtype == torch.float16:
+ x = torch.randn((30, 15), device='mps', dtype=dtype)
+ else:
+ x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
+ x_cpu = x.to("cpu")
+
+ y = x.max()
+ y_cpu = x_cpu.max()
+ self.assertEqual(y, y_cpu)
+
+ z = x.min()
+ z_cpu = x_cpu.min()
+ self.assertEqual(z, z_cpu)
+
+ for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]:
+ helper(dtype)
class TestSmoothL1Loss(TestCase):