[MPS] Add lerp implementation (#105470)
lerp.Scalar fits very well into binary op template
Add a very naive implementation for `lerp.Tensor` as `add_out(self, weight.mul(end.sub(self)))`
Enable `lerp` testing in `test_mps`
Fixes https://github.com/pytorch/pytorch/issues/105382
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105470
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index ca13f8d..2e05afb 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -18,6 +18,7 @@
#include <ATen/ops/gt_native.h>
#include <ATen/ops/hypot_native.h>
#include <ATen/ops/le_native.h>
+#include <ATen/ops/lerp_native.h>
#include <ATen/ops/logaddexp2_native.h>
#include <ATen/ops/logaddexp_native.h>
#include <ATen/ops/lt_native.h>
@@ -46,7 +47,7 @@
#define BinaryOpFn(graph, primary, secondary) \
MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)
-// alpha is always 1.0 except when this function is called from add_sub_template()
+// alpha is always 1.0 except when this function is called from add_sub_lerp_template()
void binaryOpTensor(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
@@ -173,7 +174,7 @@
feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
}
- // 'cachedGraph->alphaTensor' is not nil only if add_sub_template() was called with an alpha value != 1.0
+ // 'cachedGraph->alphaTensor' is not nil only if add_sub_lerp_template() was called with an alpha value != 1.0
if (cachedGraph->alphaTensor) {
alpha_scalar = getMPSScalar(alpha, other.scalar_type());
feeds[cachedGraph->alphaTensor] = getMPSGraphTensorFromScalar(mpsStream, alpha_scalar);
@@ -255,11 +256,11 @@
div_mode_op_block);
}
-void add_sub_template(const Tensor& self,
- const Tensor& other,
- const Scalar& alpha,
- const Tensor& output,
- std::string op_name) {
+void add_sub_lerp_template(const Tensor& self,
+ const Tensor& other,
+ const Scalar& alpha,
+ const Tensor& output,
+ std::string op_name) {
if (alpha.toDouble() == 0.0) {
if (!self.is_alias_of(output)) { // if inplace, no-op
const_cast<Tensor&>(output) = self.clone();
@@ -273,10 +274,23 @@
at::native::alpha_check(commonDtype, alpha);
}
- BinaryOpBlock add_sub_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+ if (!alpha_has_value && op_name == "lerp") {
+ if (!self.is_alias_of(other)) { // if inplace, no-op
+ output.copy_(other);
+ }
+ return;
+ }
+
+ BinaryOpBlock add_sub_lerp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* secondaryTensor = secondaryCastTensor;
+ if (op_name == "lerp") {
+ secondaryCastTensor = [mpsGraph subtractionWithPrimaryTensor:secondaryCastTensor
+ secondaryTensor:primaryCastTensor
+ name:nil];
+ }
+
// if alpha is 1.0, then we don't bother adding another multiply to graph
if (alpha_has_value) {
cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[ @1 ]);
@@ -284,7 +298,7 @@
secondaryTensor:cachedGraph->alphaTensor
name:nil];
}
- if (op_name == "add")
+ if (op_name == "add" || op_name == "lerp")
return [mpsGraph additionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
else
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
@@ -295,7 +309,7 @@
alpha,
output,
op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""),
- add_sub_op_block);
+ add_sub_lerp_op_block);
}
} // namespace mps
@@ -389,11 +403,11 @@
}
TORCH_IMPL_FUNC(add_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
- mps::add_sub_template(self, other, alpha, output, "add");
+ mps::add_sub_lerp_template(self, other, alpha, output, "add");
}
TORCH_IMPL_FUNC(sub_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
- mps::add_sub_template(self, other, alpha, output, "sub");
+ mps::add_sub_lerp_template(self, other, alpha, output, "sub");
}
TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const Tensor& out) {
@@ -492,4 +506,7 @@
mps::binaryOpTensor(self, other, Scalar(1.0), output, "xlogy_out_mps", xlogy_op_block);
}
+TORCH_IMPL_FUNC(lerp_Scalar_mps)(const Tensor& self, const Tensor& end, const Scalar& weight, const Tensor& out) {
+ mps::add_sub_lerp_template(self, end, weight, out, "lerp");
+}
} // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/Lerp.mm b/aten/src/ATen/native/mps/operations/Lerp.mm
new file mode 100644
index 0000000..ca67433
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/Lerp.mm
@@ -0,0 +1,18 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/core/Tensor.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/add.h>
+#include <ATen/ops/lerp_native.h>
+#endif
+
+namespace at::native {
+TORCH_IMPL_FUNC(lerp_Tensor_mps)(const Tensor& self, const Tensor& end, const Tensor& weight, const Tensor& out) {
+ // TODO: Write a much better implementation
+ at::add_out(const_cast<Tensor&>(out), self, weight.mul(end.sub(self)));
+}
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e4a4673..a5a93e7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9237,6 +9237,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: lerp_Scalar
+ MPS: lerp_Scalar_mps
tags: pointwise
- func: lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
@@ -9245,6 +9246,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: lerp_Tensor
+ MPS: lerp_Tensor_mps
tags: pointwise
- func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 8f16af4..211b5b4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -424,7 +424,6 @@
'isposinf': None,
'kthvalue': None,
'lcm': None,
- 'lerp': None,
'lgamma': None,
'linalg.cholesky': None,
'linalg.cholesky_ex': None,