[MPS] Add triangular solve op through MPSMatrixSolveTriangular (#94345)

Add triangular solve op support on MPS through the `MPSMatrixSolveTriangular` kernel, covering both `torch.linalg.solve_triangular` and `torch.triangular_solve`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94345
Approved by: https://github.com/razarmehr
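
A minimal usage sketch (assuming an MPS-capable machine): solve an upper-triangular system on the `"mps"` device and compare against the CPU reference. Shapes and tolerances are illustrative only.

```python
import torch

# Well-conditioned upper-triangular coefficient matrix and a right-hand side.
A = torch.triu(torch.randn(3, 3)) + 3 * torch.eye(3)
B = torch.randn(3, 2)

X_cpu = torch.linalg.solve_triangular(A, B, upper=True)
X_mps = torch.linalg.solve_triangular(A.to("mps"), B.to("mps"), upper=True)

# The MPS result should match the CPU backend up to float32 tolerance.
torch.testing.assert_close(X_mps.cpu(), X_cpu, rtol=1e-5, atol=1e-5)
```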
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 45dbb0a..d8389c1 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -1,17 +1,8 @@
 //  Copyright © 2022 Apple Inc.
 
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <ATen/Utils.h>
-#include <ATen/mps/MPSStream.h>
-#include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
-#include <torch/library.h>
-
-#ifdef __OBJC__
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#endif
-
+#include <ATen/native/LinearAlgebraUtils.h>
+#include <ATen/native/Resize.h>
 
 namespace at::native {
 
@@ -369,6 +360,7 @@
               || batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs");
 
   if (batch1.numel() == 0 || batch2.numel() == 0) {
+    result.zero_();
     return result;
   }
 
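A hedged sketch of the empty-input case the `result.zero_()` fix above covers: a batched matmul whose inner dimension is zero produces a non-empty output that is mathematically all zeros, so the buffer must be zeroed rather than returned uninitialized.

```python
import torch

a = torch.randn(2, 3, 0, device="mps")   # empty inner dimension
b = torch.randn(2, 0, 4, device="mps")

out = torch.bmm(a, b)                     # shape (2, 3, 4), an empty sum per entry
assert out.eq(0).all().item()             # output is explicitly zero-filled
```
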
@@ -596,4 +588,105 @@
   return addbmm_out_mps(self, batch1, batch2, beta, alpha, self);
 }
 
+Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) {
+  using namespace mps;
+
+  checkInputsSolver(A, B, left, "linalg.solve_triangular");
+  Tensor A_t, B_t;
+  std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
+  at::native::resize_output(out, B_t.sizes());
+
+  if (A.numel() == 0 || B.numel() == 0 || out.numel() == 0) {
+    out.zero_();
+    return out;
+  }
+
+  Tensor A_ = A_t;
+  Tensor B_ = B_t;
+  if (!A_t.is_contiguous()) {
+    A_ = A_t.clone(at::MemoryFormat::Contiguous);
+  }
+  if (!B_t.is_contiguous()) {
+    B_ = B_t.clone(at::MemoryFormat::Contiguous);
+  }
+  id<MTLBuffer> aBuffer = getMTLBufferStorage(A_);
+  id<MTLBuffer> bBuffer = getMTLBufferStorage(B_);
+  id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
+  MPSStream* mpsStream = getCurrentMPSStream();
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+
+  dispatch_sync(mpsStream->queue(), ^(){
+    @autoreleasepool {
+      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
+      uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1;
+      uint64_t aRows = A_.size(-2);
+      uint64_t bRows = B_.size(-2);
+      uint64_t aCols = A_.size(-1);
+      uint64_t bCols = B_.size(-1);
+      uint64_t aElemSize = A_.element_size();
+      uint64_t bElemSize = B_.element_size();
+
+      MPSMatrixSolveTriangular *filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device
+                                                                                     right:!left
+                                                                                     upper:upper
+                                                                                 transpose:transpose
+                                                                                      unit:unitriangular
+                                                                                     order:left ? bRows : bCols
+                                                                    numberOfRightHandSides:left ? bCols : bRows
+                                                                                     alpha:1.0f] autorelease];
+
+      MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
+                                                                                    columns:aCols
+                                                                                   matrices:batchSize
+                                                                                   rowBytes:aCols * aElemSize
+                                                                                matrixBytes:aRows * aCols * aElemSize
+                                                                                   dataType:getMPSDataType(A_.scalar_type())];
+      MPSMatrixDescriptor* rightHandSideMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:bRows
+                                                                                           columns:bCols
+                                                                                          matrices:batchSize
+                                                                                          rowBytes:bCols * bElemSize
+                                                                                       matrixBytes:bRows * bCols * bElemSize
+                                                                                          dataType:getMPSDataType(B_.scalar_type())];
+      for (const auto i: c10::irange(batchSize)) {
+        const uint64_t aBatchOffset = i * aRows * aCols;
+        const uint64_t bBatchOffset = i * bRows * bCols;
+        MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
+                                                              offset:(A_t.storage_offset() + aBatchOffset) * aElemSize
+                                                          descriptor:sourceMatrixDesc] autorelease];
+        MPSMatrix* rightHandSideMatrix = [[[MPSMatrix alloc] initWithBuffer:bBuffer
+                                                                     offset:(B_t.storage_offset() + bBatchOffset) * bElemSize
+                                                                 descriptor:rightHandSideMatrixDesc] autorelease];
+        MPSMatrix *solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer
+                                                                offset:(out.storage_offset() + bBatchOffset) * bElemSize
+                                                            descriptor:rightHandSideMatrixDesc] autorelease];
+
+        [filter encodeToCommandBuffer:commandBuffer
+                         sourceMatrix:sourceMatrix
+                  rightHandSideMatrix:rightHandSideMatrix
+                       solutionMatrix:solutionMatrix];
+      }
+      mpsStream->commit(true);
+    }
+  });
+  return out;
+}
+
+Tensor& linalg_solve_triangular_mps_out( const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular, Tensor& out) {
+  return linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out);
+}
+
+Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular) {
+  Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous);
+  linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out);
+  return out;
+}
+
+TORCH_IMPL_FUNC(triangular_solve_mps_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) {
+  clone_A.copy_(A);
+  Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous);
+  linalg_solve_triangular_mps_impl(A, self, upper, transpose, /*left=*/true, unitriangular, out);
+  result.resize_(out.sizes());
+  result.copy_(out);
+}
+
 } // namespace at::native
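
A hedged sketch of the batched/broadcast path exercised by the new implementation: `B` is broadcast against `A`'s batch dimensions via `_linalg_broadcast_batch_dims`, and non-contiguous operands are cloned to contiguous storage before being wrapped in `MPSMatrix` descriptors. Shapes below are illustrative only.

```python
import torch

# Batch of 4 well-conditioned lower-triangular matrices on MPS.
A = torch.randn(4, 3, 3, device="mps").tril() + 3 * torch.eye(3, device="mps")
# Right-hand side without a batch dimension; it is broadcast over the batch of 4.
B = torch.randn(3, 5, device="mps")

X = torch.linalg.solve_triangular(A, B, upper=False)   # X: (4, 3, 5)

# Verify A @ X recovers the broadcast right-hand side, per batch element.
torch.testing.assert_close(torch.matmul(A, X).cpu(), B.expand(4, 3, 5).cpu(),
                           rtol=1e-4, atol=1e-4)
```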
diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm
index e9469c4..a4b0db9 100644
--- a/aten/src/ATen/native/mps/operations/TriangularOps.mm
+++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm
@@ -19,6 +19,10 @@
  const Tensor &output) {
 
   using namespace mps;
+
+  if (self.numel() == 0) {
+    return;
+  }
   MPSStream* stream = getCurrentMPSStream();
 
   // Derive from MPSCachedGraph
@@ -98,6 +102,10 @@
  const Tensor &output) {
 
   using namespace mps;
+
+  if (self.numel() == 0) {
+    return;
+  }
   MPSStream* stream = getCurrentMPSStream();
 
   // Derive from MPSCachedGraph
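
A hedged sketch of the zero-element early returns added above: `triu`/`tril` on an empty MPS tensor now simply return an equally empty result instead of encoding a graph over zero-sized buffers.

```python
import torch

x = torch.empty(0, 4, device="mps")
assert torch.triu(x).shape == (0, 4)
assert torch.tril(x).shape == (0, 4)
```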
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index bec46a0..e7b25c8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8668,6 +8668,7 @@
   structured: True
   dispatch:
     CPU, CUDA: triangular_solve_out
+    MPS: triangular_solve_mps_out
     SparseCsrCPU: triangular_solve_out_sparse_csr_cpu
     SparseCsrCUDA: triangular_solve_out_sparse_csr_cuda
 
@@ -8683,12 +8684,14 @@
   python_module: linalg
   dispatch:
     CPU, CUDA: linalg_solve_triangular_out
+    MPS: linalg_solve_triangular_mps_out
 
 - func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
   python_module: linalg
   variants: function
   dispatch:
     CPU, CUDA: linalg_solve_triangular
+    MPS: linalg_solve_triangular_mps
 
 - func: linalg_vander(Tensor x, *, int? N=None) -> Tensor
   python_module: linalg
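
A hedged sketch of the legacy entry point registered above: with the `triangular_solve_mps_out` dispatch in place, `torch.triangular_solve` on MPS tensors routes to the new kernel and returns the solution together with the cloned coefficient matrix.

```python
import torch

A = torch.triu(torch.randn(3, 3, device="mps")) + 3 * torch.eye(3, device="mps")
b = torch.randn(3, 2, device="mps")

solution, cloned_A = torch.triangular_solve(b, A, upper=True)

# The solution satisfies A @ X = b up to float32 tolerance.
torch.testing.assert_close(torch.matmul(A, solution).cpu(), b.cpu(),
                           rtol=1e-4, atol=1e-4)
```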
diff --git a/test/test_mps.py b/test/test_mps.py
index 3b3bf9e..79256e7 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8528,6 +8528,8 @@
         'native_layer_norm': ['torch.float32'],
         'nn.functional.layer_norm': ['torch.float32'],
         'nn.functional.bilinear': ['f32'],
+        'linalg.solve_triangular': ['f32'],
+        'triangular_solve': ['f32'],
     }
 
 
@@ -8704,7 +8706,9 @@
         'view_as': ['f16', 'f32'],
         'vsplit': ['f16', 'f32'],
         'vstack': ['f16', 'f32'],
-        'zero_': ['f16', 'f32']
+        'zero_': ['f16', 'f32'],
+        'linalg.solve_triangular': ['f32'],
+        'triangular_solve': ['f32'],
     }
 
     # These ops that are problematic. So never run them even when