[MPS] Add hypot op (#95196)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95196
Approved by: https://github.com/kulinseth
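
This adds a native MPS implementation of `torch.hypot`, composed from MPSGraph `power`, `addition`, and `squareRoot` ops (i.e. `hypot(x, y) = sqrt(x^2 + y^2)` elementwise), and registers it for the `hypot.out` overload under the `MPS` dispatch key. A minimal usage sketch (illustrative only, not part of the patch; assumes a PyTorch build with MPS support on Apple silicon, and the sample values are chosen just to show the expected result):

```python
import torch

# Illustrative sketch: exercises the new MPS kernel when an MPS device is
# available; the input values are arbitrary examples.
if torch.backends.mps.is_available():
    a = torch.tensor([3.0, 5.0], device="mps")
    b = torch.tensor([4.0, 12.0], device="mps")
    # Elementwise sqrt(a**2 + b**2) -> tensor([5., 13.], device='mps:0')
    print(torch.hypot(a, b))
```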
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 6569e59..b87dab0 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -370,6 +370,25 @@
   mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out");
 }
 
+TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
+{
+  mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+    MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0
+                                                       shape:@[@1]
+                                                    dataType:primaryCastTensor.dataType];
+    MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
+                                                                                     secondaryTensor:twoTensor
+                                                                                                name:nil]
+                                                    secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
+                                                                                     secondaryTensor:twoTensor
+                                                                                                name:nil]
+                                                               name:nil];
+    return [mpsGraph squareRootWithTensor:sumTensor name:nil];
+  };
+  mps::binaryOpTensor(self, other, Scalar(1.0), output, "hypot_out_mps", hypot_op_block);
+}
+
 TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
 {
   mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2f7a1a8..3772bb5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9164,6 +9164,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: hypot_out
+    MPS: hypot_out_mps
   tags: pointwise
 
 - func: hypot(Tensor self, Tensor other) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index e355ac9..1e1f217 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9253,6 +9253,7 @@
         'gt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'half': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'hstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'hypot': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'index_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'index_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'int': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9510,6 +9511,7 @@
         'gradient': ['f32'],
         'half': ['f16'],
         'hstack': ['f16', 'f32'],
+        'hypot': ['f16', 'f32'],
         'index_select': ['f16', 'f32'],
         'index_add': ['f16', 'f32'],
         'isclose': ['f16', 'f32'],