[MPS] Add triangular solve op through MPSMatrixSolveTriangular (#94345)
Add triangular solve op support through MPS `MPSMatrixSolveTriangular` kernel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94345
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 45dbb0a..d8389c1 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -1,17 +1,8 @@
// Copyright © 2022 Apple Inc.
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <ATen/Utils.h>
-#include <ATen/mps/MPSStream.h>
-#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/mps/OperationUtils.h>
-#include <torch/library.h>
-
-#ifdef __OBJC__
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#endif
-
+#include <ATen/native/LinearAlgebraUtils.h>
+#include <ATen/native/Resize.h>
namespace at::native {
@@ -369,6 +360,7 @@
|| batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs");
if (batch1.numel() == 0 || batch2.numel() == 0) {
+ result.zero_();
return result;
}
@@ -596,4 +588,119 @@
   return addbmm_out_mps(self, batch1, batch2, beta, alpha, self);
 }
 
+// Solves the triangular system op(A)·X = B (left) or X·op(A) = B (right)
+// on MPS using MPSMatrixSolveTriangular, writing the solution into `out`.
+Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) {
+  using namespace mps;
+
+  checkInputsSolver(A, B, left, "linalg.solve_triangular");
+  Tensor A_t, B_t;
+  std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
+  at::native::resize_output(out, B_t.sizes());
+
+  if (A.numel() == 0 || B.numel() == 0 || out.numel() == 0) {
+    out.zero_();
+    return out;
+  }
+
+  // MPSMatrix requires dense row-major storage; materialize contiguous
+  // copies when the broadcasted views are not already contiguous.
+  Tensor A_ = A_t;
+  Tensor B_ = B_t;
+  if (!A_t.is_contiguous()) {
+    A_ = A_t.clone(at::MemoryFormat::Contiguous);
+  }
+  if (!B_t.is_contiguous()) {
+    B_ = B_t.clone(at::MemoryFormat::Contiguous);
+  }
+  id<MTLBuffer> aBuffer = getMTLBufferStorage(A_);
+  id<MTLBuffer> bBuffer = getMTLBufferStorage(B_);
+  id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
+  MPSStream* mpsStream = getCurrentMPSStream();
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+
+  dispatch_sync(mpsStream->queue(), ^(){
+    @autoreleasepool {
+      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
+      // NOTE(review): for ndim > 3 only the leading batch dim is counted;
+      // A_.numel() / (aRows * aCols) would cover all flattened batch dims.
+      uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1;
+      uint64_t aRows = A_.size(-2);
+      uint64_t bRows = B_.size(-2);
+      uint64_t aCols = A_.size(-1);
+      uint64_t bCols = B_.size(-1);
+      uint64_t aElemSize = A_.element_size();
+      uint64_t bElemSize = B_.element_size();
+
+      MPSMatrixSolveTriangular *filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device
+                                                                                     right:!left
+                                                                                     upper:upper
+                                                                                 transpose:transpose
+                                                                                      unit:unitriangular
+                                                                                     order:left ? bRows : bCols
+                                                                    numberOfRightHandSides:left ? bCols : bRows
+                                                                                     alpha:1.0f] autorelease];
+
+      MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
+                                                                                    columns:aCols
+                                                                                   matrices:batchSize
+                                                                                   rowBytes:aCols * aElemSize
+                                                                                matrixBytes:aRows * aCols * aElemSize
+                                                                                   dataType:getMPSDataType(A_.scalar_type())];
+      MPSMatrixDescriptor* rightHandSideMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:bRows
+                                                                                           columns:bCols
+                                                                                          matrices:batchSize
+                                                                                          rowBytes:bCols * bElemSize
+                                                                                       matrixBytes:bRows * bCols * bElemSize
+                                                                                          dataType:getMPSDataType(B_.scalar_type())];
+      for (const auto i: c10::irange(batchSize)) {
+        const uint64_t aBatchOffset = i * aRows * aCols;
+        const uint64_t bBatchOffset = i * bRows * bCols;
+        // Offsets must come from the possibly-cloned contiguous tensors
+        // (A_/B_): after a clone the backing buffer starts at offset 0,
+        // not at the original view's storage offset.
+        MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
+                                                              offset:(A_.storage_offset() + aBatchOffset) * aElemSize
+                                                          descriptor:sourceMatrixDesc] autorelease];
+        MPSMatrix* rightHandSideMatrix = [[[MPSMatrix alloc] initWithBuffer:bBuffer
+                                                                     offset:(B_.storage_offset() + bBatchOffset) * bElemSize
+                                                                 descriptor:rightHandSideMatrixDesc] autorelease];
+        // NOTE(review): assumes `out` is contiguous; resize_output on a fresh
+        // MPS tensor yields contiguous storage — verify for user-supplied outs.
+        MPSMatrix *solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer
+                                                                offset:(out.storage_offset() + bBatchOffset) * bElemSize
+                                                            descriptor:rightHandSideMatrixDesc] autorelease];
+
+        [filter encodeToCommandBuffer:commandBuffer
+                         sourceMatrix:sourceMatrix
+                  rightHandSideMatrix:rightHandSideMatrix
+                       solutionMatrix:solutionMatrix];
+      }
+      mpsStream->commit(true);
+    }
+  });
+  return out;
+}
+
+// out-variant entry point registered for MPS (transpose is always false here).
+Tensor& linalg_solve_triangular_mps_out( const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular, Tensor& out) {
+  return linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out);
+}
+
+// Functional variant: allocates the result tensor on MPS.
+Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular) {
+  Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous);
+  linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out);
+  return out;
+}
+
+// Structured kernel for torch.triangular_solve: clone_A mirrors the input A.
+TORCH_IMPL_FUNC(triangular_solve_mps_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) {
+  clone_A.copy_(A);
+  Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous);
+  linalg_solve_triangular_mps_impl(A, self, upper, transpose, /*left=*/true, unitriangular, out);
+  result.resize_(out.sizes());
+  result.copy_(out);
+}
+
 } // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm
index e9469c4..a4b0db9 100644
--- a/aten/src/ATen/native/mps/operations/TriangularOps.mm
+++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm
@@ -19,6 +19,10 @@
const Tensor &output) {
using namespace mps;
+
+ if (self.numel() == 0) {
+ return;
+ }
MPSStream* stream = getCurrentMPSStream();
// Derive from MPSCachedGraph
@@ -98,6 +102,10 @@
const Tensor &output) {
using namespace mps;
+
+ if (self.numel() == 0) {
+ return;
+ }
MPSStream* stream = getCurrentMPSStream();
// Derive from MPSCachedGraph
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index bec46a0..e7b25c8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8668,6 +8668,7 @@
structured: True
dispatch:
CPU, CUDA: triangular_solve_out
+ MPS: triangular_solve_mps_out
SparseCsrCPU: triangular_solve_out_sparse_csr_cpu
SparseCsrCUDA: triangular_solve_out_sparse_csr_cuda
@@ -8683,12 +8684,14 @@
python_module: linalg
dispatch:
CPU, CUDA: linalg_solve_triangular_out
+ MPS: linalg_solve_triangular_mps_out
- func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
python_module: linalg
variants: function
dispatch:
CPU, CUDA: linalg_solve_triangular
+ MPS: linalg_solve_triangular_mps
- func: linalg_vander(Tensor x, *, int? N=None) -> Tensor
python_module: linalg
diff --git a/test/test_mps.py b/test/test_mps.py
index 3b3bf9e..79256e7 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8528,6 +8528,8 @@
'native_layer_norm': ['torch.float32'],
'nn.functional.layer_norm': ['torch.float32'],
'nn.functional.bilinear': ['f32'],
+ 'linalg.solve_triangular': ['f32'],
+ 'triangular_solve': ['f32'],
}
@@ -8704,7 +8706,9 @@
'view_as': ['f16', 'f32'],
'vsplit': ['f16', 'f32'],
'vstack': ['f16', 'f32'],
- 'zero_': ['f16', 'f32']
+ 'zero_': ['f16', 'f32'],
+ 'linalg.solve_triangular': ['f32'],
+ 'triangular_solve': ['f32'],
}
# These ops that are problematic. So never run them even when