[MPS] Add lerp implementation (#105470)

`lerp.Scalar` fits very well into the binary op template.
Add a very naive implementation of `lerp.Tensor` as `add_out(self, weight.mul(end.sub(self)))`.
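
For reference, `lerp(self, end, weight)` computes `self + weight * (end - self)`, and the naive `lerp.Tensor` path relies on exactly that identity. A minimal sketch of the equivalence in plain PyTorch (CPU, not the MPS code itself):

```python
import torch

# lerp(self, end, weight) == self + weight * (end - self)
self_t, end_t, weight_t = torch.rand(4), torch.rand(4), torch.rand(4)

expected = torch.lerp(self_t, end_t, weight_t)
# Mirrors add_out(self, weight.mul(end.sub(self))) from this PR
naive = self_t + weight_t.mul(end_t.sub(self_t))
assert torch.allclose(expected, naive)
```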

Enable `lerp` testing in `test_mps`
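
Not the actual test code, just a minimal sketch of the check this enables (assumes an MPS-capable machine):

```python
import torch

# Sanity-check the MPS result against the CPU reference.
if torch.backends.mps.is_available():
    x, y = torch.rand(8), torch.rand(8)
    cpu_out = torch.lerp(x, y, 0.25)                      # CPU reference
    mps_out = torch.lerp(x.to("mps"), y.to("mps"), 0.25)  # dispatched to lerp_Scalar_mps
    assert torch.allclose(cpu_out, mps_out.cpu())
```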

Fixes https://github.com/pytorch/pytorch/issues/105382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105470
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index ca13f8d..2e05afb 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -18,6 +18,7 @@
 #include <ATen/ops/gt_native.h>
 #include <ATen/ops/hypot_native.h>
 #include <ATen/ops/le_native.h>
+#include <ATen/ops/lerp_native.h>
 #include <ATen/ops/logaddexp2_native.h>
 #include <ATen/ops/logaddexp_native.h>
 #include <ATen/ops/lt_native.h>
@@ -46,7 +47,7 @@
 #define BinaryOpFn(graph, primary, secondary) \
   MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)
 
-// alpha is always 1.0 except when this function is called from add_sub_template()
+// alpha is always 1.0 except when this function is called from add_sub_lerp_template()
 void binaryOpTensor(const Tensor& self,
                     const Tensor& other,
                     const Scalar& alpha,
@@ -173,7 +174,7 @@
       feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
     }
 
-    // 'cachedGraph->alphaTensor' is not nil only if add_sub_template() was called with an alpha value != 1.0
+    // 'cachedGraph->alphaTensor' is not nil only if add_sub_lerp_template() was called with an alpha value != 1.0
     if (cachedGraph->alphaTensor) {
       alpha_scalar = getMPSScalar(alpha, other.scalar_type());
       feeds[cachedGraph->alphaTensor] = getMPSGraphTensorFromScalar(mpsStream, alpha_scalar);
@@ -255,11 +256,11 @@
                  div_mode_op_block);
 }
 
-void add_sub_template(const Tensor& self,
-                      const Tensor& other,
-                      const Scalar& alpha,
-                      const Tensor& output,
-                      std::string op_name) {
+void add_sub_lerp_template(const Tensor& self,
+                           const Tensor& other,
+                           const Scalar& alpha,
+                           const Tensor& output,
+                           std::string op_name) {
   if (alpha.toDouble() == 0.0) {
     if (!self.is_alias_of(output)) { // if inplace, no-op
       const_cast<Tensor&>(output) = self.clone();
@@ -273,10 +274,23 @@
     at::native::alpha_check(commonDtype, alpha);
   }
 
-  BinaryOpBlock add_sub_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+  if (!alpha_has_value && op_name == "lerp") {
+    if (!self.is_alias_of(other)) { // if inplace, no-op
+      output.copy_(other);
+    }
+    return;
+  }
+
+  BinaryOpBlock add_sub_lerp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();
     MPSGraphTensor* secondaryTensor = secondaryCastTensor;
 
+    if (op_name == "lerp") {
+      secondaryCastTensor = [mpsGraph subtractionWithPrimaryTensor:secondaryCastTensor
+                                                   secondaryTensor:primaryCastTensor
+                                                              name:nil];
+    }
+
     // if alpha is 1.0, then we don't bother adding another multiply to graph
     if (alpha_has_value) {
       cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[ @1 ]);
@@ -284,7 +298,7 @@
                                                   secondaryTensor:cachedGraph->alphaTensor
                                                              name:nil];
     }
-    if (op_name == "add")
+    if (op_name == "add" || op_name == "lerp")
       return [mpsGraph additionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
     else
       return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
@@ -295,7 +309,7 @@
                  alpha,
                  output,
                  op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""),
-                 add_sub_op_block);
+                 add_sub_lerp_op_block);
 }
 
 } // namespace mps
@@ -389,11 +403,11 @@
 }
 
 TORCH_IMPL_FUNC(add_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
-  mps::add_sub_template(self, other, alpha, output, "add");
+  mps::add_sub_lerp_template(self, other, alpha, output, "add");
 }
 
 TORCH_IMPL_FUNC(sub_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
-  mps::add_sub_template(self, other, alpha, output, "sub");
+  mps::add_sub_lerp_template(self, other, alpha, output, "sub");
 }
 
 TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const Tensor& out) {
@@ -492,4 +506,7 @@
   mps::binaryOpTensor(self, other, Scalar(1.0), output, "xlogy_out_mps", xlogy_op_block);
 }
 
+TORCH_IMPL_FUNC(lerp_Scalar_mps)(const Tensor& self, const Tensor& end, const Scalar& weight, const Tensor& out) {
+  mps::add_sub_lerp_template(self, end, weight, out, "lerp");
+}
 } // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/Lerp.mm b/aten/src/ATen/native/mps/operations/Lerp.mm
new file mode 100644
index 0000000..ca67433
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/Lerp.mm
@@ -0,0 +1,18 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/core/Tensor.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/add.h>
+#include <ATen/ops/lerp_native.h>
+#endif
+
+namespace at::native {
+TORCH_IMPL_FUNC(lerp_Tensor_mps)(const Tensor& self, const Tensor& end, const Tensor& weight, const Tensor& out) {
+  // TODO: Write a much better implementation
+  at::add_out(const_cast<Tensor&>(out), self, weight.mul(end.sub(self)));
+}
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e4a4673..a5a93e7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9237,6 +9237,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: lerp_Scalar
+    MPS: lerp_Scalar_mps
   tags: pointwise
 
 - func: lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
@@ -9245,6 +9246,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: lerp_Tensor
+    MPS: lerp_Tensor_mps
   tags: pointwise
 
 - func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 8f16af4..211b5b4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -424,7 +424,6 @@
         'isposinf': None,
         'kthvalue': None,
         'lcm': None,
-        'lerp': None,
         'lgamma': None,
         'linalg.cholesky': None,
         'linalg.cholesky_ex': None,