Revert "MPS: add ranked tensors for addcmul ops instead of constants and update version_check (#78312)"
This reverts commit 59b6052dadc3bd60e079a493c72286c037be28ae.
Reverted https://github.com/pytorch/pytorch/pull/78312 on behalf of https://github.com/malfet due to , as it contains syntactic error that breaks Mac builds, see https://hud.pytorch.org/pytorch/pytorch/commit/59b6052dadc3bd60e079a493c72286c037be28ae
diff --git a/aten/src/ATen/mps/EmptyTensor.cpp b/aten/src/ATen/mps/EmptyTensor.cpp
index 759aef7..21f47e1 100644
--- a/aten/src/ATen/mps/EmptyTensor.cpp
+++ b/aten/src/ATen/mps/EmptyTensor.cpp
@@ -25,7 +25,7 @@
c10::optional<c10::MemoryFormat> memory_format_opt) {
#if defined(__APPLE__)
#if __is_target_os(macOS)
- if (at::hasMPS()) {
+ if (__builtin_available(macOS 12.3, *) || __builtin_available(macOSApplicationExtension 12.3, *)) {
auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::MPS);
@@ -87,7 +87,7 @@
c10::optional<Device> device_opt) {
#if defined(__APPLE__)
#if __is_target_os(macOS)
- if (at::hasMPS()) {
+ if (__builtin_available(macOS 12.3, *) || __builtin_available(macOSApplicationExtension 12.3, *)) {
auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT(device.is_mps());
TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm
index ecf0755..2e89dfa 100644
--- a/aten/src/ATen/mps/MPSDevice.mm
+++ b/aten/src/ATen/mps/MPSDevice.mm
@@ -22,17 +22,10 @@
MPSDevice::MPSDevice(): _mtl_device(nil) {
// Check that MacOS 12.3+ version of MPS framework is available
- // Create the MPSGraph and check method introduced in 12.3+
- // which is used by MPS backend.
- id mpsCD = NSClassFromString(@"MPSGraph");
- if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor:
- recurrentWeight:
- inputWeight:
- bias:
- initState:
- initCell:
- descriptor:
- name:)] == NO)) {
+ id mpsCD = NSClassFromString(@"MPSGraphCompilationDescriptor");
+ if (![mpsCD instancesRespondToSelector:@selector(optimizationLevel)]) {
+ // According to https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor/3922624-optimizationlevel
+ // this means we are running on older MacOS
return;
}
NSArray* devices = [MTLCopyAllDevices() autorelease];
diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
index 66427c7..f5f6be0 100644
--- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm
+++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
@@ -29,7 +29,7 @@
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
- string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false);
+ string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false) + ":" + getMPSTypeString(value_opt.type());
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
@@ -44,7 +44,7 @@
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
- newCachedGraph->valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()));
+ newCachedGraph->valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(value_opt.type()));
// the tensor to be optionally multiplied by value_scalar
MPSGraphTensor *multiplicandTensor = nil;
@@ -81,7 +81,7 @@
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
tensor1Placeholder.getMPSGraphTensor() : tensor1Placeholder.getMPSGraphTensorData(),
tensor2Placeholder.getMPSGraphTensor() : tensor2Placeholder.getMPSGraphTensorData(),
- cachedGraph->valueTensor : getMPSGraphTensorFromScalar(mpsStream, value_opt, getMPSScalarType(self.scalar_type())),
+ cachedGraph->valueTensor : getMPSGraphTensorFromScalar(mpsStream, value_opt, getMPSScalarType(value_opt.type())),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{