[MPS] Fixes for Binary ops with casting issues from FP to uint8 (#94382)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94382
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 17aa58d..1e47b57 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -1,17 +1,7 @@
 //  Copyright © 2022 Apple Inc.
 
-#include <ATen/mps/MPSStream.h>
 #include <ATen/native/mps/Copy.h>
 #include <ATen/native/mps/OperationUtils.h>
-#include <iostream>
-#include <cstring>
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <ATen/Utils.h>
-#include <torch/library.h>
-#include <ATen/native/Resize.h>
-#include <c10/util/Optional.h>
-
 
 namespace at::native {
 namespace mps {
@@ -84,7 +74,11 @@
           newCachedGraph = new CachedGraph(mpsGraph);
 
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);
-          MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputTensor toType:dstDType name:@"cast"];
+          MPSGraphTensor* inputCastTensor = inputTensor;
+          if (isFloatingType(src.scalar_type()) && dstDType == MPSDataTypeUInt8) {
+            inputCastTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"cast"];
+          }
+          MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputCastTensor toType:dstDType name:@"cast"];
 
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->outputTensor_ = outputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index 2c344ad..68a95b8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8570,8 +8570,10 @@
         'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'bmm': ['f32'],
         'broadcast_shapes': ['f32'],
+        'byte': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'ceil': ['f32', 'int32', 'int64', 'f16'],
-        'char': ['b8', 'u8'],
+        'char': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8607,17 +8609,19 @@
         'flip': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'fliplr': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'float': ['f32'],
+        'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
         'floor_divide': ['f32', 'f16'],
         'frac': ['f16', 'f32'],
         'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'gradient': ['f16', 'f32', 'i16'],
-        'half': ['f16'],
+        'ge': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'gt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'half': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'hstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'index_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'index_add': ['f16', 'f32', 'i16', 'i32'],
-        'int': ['i32'],
+        'int': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'isclose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'isfinite': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'isinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8724,7 +8728,7 @@
         'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'short': ['i16'],
+        'short': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
         'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
         'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],