[MPS] Fixes for Binary ops with casting issues from FP to uint8 (#94382)
Fixes incorrect results when casting floating-point tensors to uint8 on MPS by routing the cast through an intermediate int32 cast (issue number not recorded in the original message).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94382
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 17aa58d..1e47b57 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -1,17 +1,7 @@
// Copyright © 2022 Apple Inc.
-#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/mps/OperationUtils.h>
-#include <iostream>
-#include <cstring>
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <ATen/Utils.h>
-#include <torch/library.h>
-#include <ATen/native/Resize.h>
-#include <c10/util/Optional.h>
-
namespace at::native {
namespace mps {
@@ -84,7 +74,11 @@
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);
- MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputTensor toType:dstDType name:@"cast"];
+ MPSGraphTensor* inputCastTensor = inputTensor;
+ if (isFloatingType(src.scalar_type()) && dstDType == MPSDataTypeUInt8) {
+ inputCastTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"cast"];
+ }
+ MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputCastTensor toType:dstDType name:@"cast"];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index 2c344ad..68a95b8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8570,8 +8570,10 @@
'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
'bmm': ['f32'],
'broadcast_shapes': ['f32'],
+ 'byte': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'cat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'ceil': ['f32', 'int32', 'int64', 'f16'],
- 'char': ['b8', 'u8'],
+ 'char': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'],
'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8607,17 +8609,19 @@
'flip': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'fliplr': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
- 'float': ['f32'],
+ 'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
'floor_divide': ['f32', 'f16'],
'frac': ['f16', 'f32'],
'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'gradient': ['f16', 'f32', 'i16'],
- 'half': ['f16'],
+ 'ge': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'gt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'half': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'hstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'index_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'index_add': ['f16', 'f32', 'i16', 'i32'],
- 'int': ['i32'],
+ 'int': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'isclose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'isfinite': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'isinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8724,7 +8728,7 @@
'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
- 'short': ['i16'],
+ 'short': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],