[MPS] Exclude int64 dtype from reduction ops (#91272)

Reduction ops on MPS don't support the int64 data type. This PR adds an assert when int64 is used with the min/max reduction ops.
All other integer dtypes are cast to int32.
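
A minimal sketch of the resulting user-visible behavior (assumes a machine with an MPS device; a TORCH_CHECK failure surfaces as RuntimeError in Python):

    import torch

    # int64 (Long) min/max on MPS now raises instead of hitting an unsupported path.
    x64 = torch.arange(4, device="mps", dtype=torch.int64)
    try:
        x64.max()
    except RuntimeError as err:
        print(err)  # min/max not supported for Long dtype on MPS

    # Other integer dtypes are widened to int32 inside the graph and match CPU.
    x8 = torch.randint(0, 100, (30, 15), device="mps", dtype=torch.int8)
    assert x8.max() == x8.cpu().max()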
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91272
Approved by: https://github.com/razarmehr, https://github.com/malfet
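
For reference, the dtype gating added to ReduceOps.mm below behaves like this Python sketch (the names NATIVE_REDUCTION_DTYPES and mps_min_max_compute_dtype are illustrative only, not part of the change):

    import torch

    # Dtypes this path reduces natively; every other integer-like dtype is cast.
    NATIVE_REDUCTION_DTYPES = {torch.float32, torch.int32, torch.float16}

    def mps_min_max_compute_dtype(dtype: torch.dtype) -> torch.dtype:
        # Long is rejected up front; other non-native dtypes are widened to int32.
        if dtype == torch.int64:
            raise RuntimeError("min/max not supported for Long dtype on MPS")
        return dtype if dtype in NATIVE_REDUCTION_DTYPES else torch.int32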
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 684837f..43b3336 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -34,6 +34,16 @@
 
 using namespace mps;
 
+// Collect all dimension indices [0, ndim) of the tensor so callers can reduce over every axis.
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
+  int64_t ndim = t.dim();
+  auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+  for (const auto i: c10::irange(ndim)) {
+    axes[i] = [NSNumber numberWithInteger:i];
+  }
+  return axes;
+}
+
 void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
                          NSMutableArray<NSNumber*> * &apparent_in_shape,
                          int64_t num_reduce_dims,
@@ -1091,15 +1100,12 @@
 Tensor min_max_mps(const Tensor& input_t,
                    MPSReductionType reduction_type,
                    const std::string& func_name) {
+  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "min/max not supported for Long dtype on MPS");
   using CachedGraph = MPSUnaryCachedGraph;
 
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   IntArrayRef input_shape = input_t.sizes();
-
-  // Flatten the input tensor to reduce it to one value
-  NSMutableArray<NSNumber*> *apparent_input_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
   int64_t num_in_elements = c10::multiply_integers(input_shape);
-  apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements];
 
   Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
 
@@ -1121,14 +1127,27 @@
           MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
 
           MPSGraphTensor* outputTensor = nil;
+          MPSGraphTensor* castInputTensor = nil;
 
+          // Only float32, int32 and float16 are reduced natively here; cast the
+          // remaining (bool and smaller integer) dtypes to int32.
+          if (input_t.scalar_type() != ScalarType::Float &&
+              input_t.scalar_type() != ScalarType::Int   &&
+              input_t.scalar_type() != ScalarType::Half) {
+            castInputTensor =  [mpsGraph castTensor:inputTensor
+                                             toType:MPSDataTypeInt32
+                                               name:@"castInputTensor"];
+          } else {
+            castInputTensor = inputTensor;
+          }
+
+          NSArray<NSNumber*>* axes = getTensorAxes(input_t);
           if (reduction_type == MPSReductionType::MAX) {
-            outputTensor = [mpsGraph reductionMaximumWithTensor: inputTensor
-                                                           axes: @[@0]
+            outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
+                                                           axes:axes
                                                            name: nil];
           } else if (reduction_type == MPSReductionType::MIN) {
-            outputTensor = [mpsGraph reductionMinimumWithTensor: inputTensor
-                                                           axes: @[@0]
+            outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
+                                                           axes:axes
                                                            name: nil];
           }
 
@@ -1139,7 +1157,7 @@
       });
     }
 
-    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
+    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
     auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, @[@1]);
 
     NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
@@ -1175,6 +1193,7 @@
                      const Tensor& indices_t,
                      MPSReductionType reduction_type,
                      const std::string& func_name) {
+  TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "min/max not supported for Long dtype on MPS");
 
   if (output_t.numel() == 0) {
     return;
@@ -1240,7 +1259,7 @@
               input_t.scalar_type() != ScalarType::Int   &&
               input_t.scalar_type() != ScalarType::Half) {
             castInputTensor =  [mpsGraph castTensor:inputTensor
-                                             toType:MPSDataTypeFloat32
+                                             toType:MPSDataTypeInt32
                                                name:@"castInputTensor"];
           } else {
             castInputTensor = inputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index bc93e0d..3550a5a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1880,6 +1880,25 @@
         helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
         helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
 
+    def test_min_max(self):
+        def helper(dtype):
+            for _ in range(10):
+                if dtype == torch.float32 or dtype == torch.float16:
+                    x = torch.randn((30, 15), device='mps', dtype=dtype)
+                else:
+                    x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
+                x_cpu = x.to("cpu")
+
+                y = x.max()
+                y_cpu = x_cpu.max()
+                self.assertEqual(y, y_cpu)
+
+                z = x.min()
+                z_cpu = x_cpu.min()
+                self.assertEqual(z, z_cpu)
+
+        for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]:
+            helper(dtype)
 
 class TestSmoothL1Loss(TestCase):