Fix type promotion for cosine_similarity() (#62054)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61454
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62054
Reviewed By: suo
Differential Revision: D29881755
Pulled By: jbschlosser
fbshipit-source-id: 10499766ac07b0ae3c0d2f4c426ea818d1e77db6
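
For context, a minimal repro of the behavior exercised by the new test below (a sketch only; the variable names are illustrative and the pre-fix failure mode is not reproduced here):

    import torch
    import torch.nn.functional as F

    # Mixed-dtype inputs: an int8 tensor and the default float32 tensor,
    # matching the case added to test_nn.py in this patch.
    a = torch.tensor(12.)
    b = a.to(torch.int8)

    # With the promotion added here, both inputs are first cast to
    # torch.result_type(a, b) (float32), so the similarity of a value
    # with itself evaluates to 1.
    out = F.cosine_similarity(b, a, dim=-1)
    print(out)  # tensor(1.)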
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index 843262a..2d3164d 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -237,11 +237,14 @@
 Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
   TORCH_CHECK(x1.sizes() == x2.sizes(), "cosine_similarity requires both inputs to have the same sizes, but x1 has ",
     x1.sizes(), " and x2 has ", x2.sizes())
+  auto commonDtype = at::result_type(x1, x2);
+  Tensor x1_ = x1.to(commonDtype);
+  Tensor x2_ = x2.to(commonDtype);
   // Follow scipy impl to improve numerical precision
   // Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
-  Tensor w12 = at::sum(x1 * x2, dim);
-  Tensor w1 = at::sum(x1 * x1, dim);
-  Tensor w2 = at::sum(x2 * x2, dim);
+  Tensor w12 = at::sum(x1_ * x2_, dim);
+  Tensor w1 = at::sum(x1_ * x1_, dim);
+  Tensor w2 = at::sum(x2_ * x2_, dim);
   Tensor n12 = (w1 * w2).clamp_min_(eps * eps).sqrt_();
   return w12.div_(n12);
 }
diff --git a/test/test_nn.py b/test/test_nn.py
index 3e8d826..962386c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9382,6 +9382,11 @@
         with self.assertRaises(RuntimeError):
             F.cosine_similarity(input1, input2)
 
+        # Check type promotion, issue #61454
+        input = torch.tensor(12.)
+        out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
+        self.assertEqual(out, 1.)
+
     def test_grid_sample_error_checking(self):
         input = torch.empty(1, 1, 2, 2)
         grid = torch.empty(1, 1, 1, 2)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 3621d70..70d69d7 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4243,6 +4243,8 @@
 .. math ::
     \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
 
+Supports :ref:`type promotion <type-promotion-doc>`.
+
 Args:
     x1 (Tensor): First input.
     x2 (Tensor): Second input (of size matching x1).
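
For readers following the formula above, a rough Python sketch of what the patched ATen path computes (an illustration only, not the actual kernel; the helper name `_cosine_similarity_ref` is made up):

    import torch

    def _cosine_similarity_ref(x1, x2, dim=1, eps=1e-8):
        # Promote both inputs to a common dtype, as the C++ change does
        # with at::result_type / Tensor::to.
        common_dtype = torch.result_type(x1, x2)
        x1 = x1.to(common_dtype)
        x2 = x2.to(common_dtype)
        # Follow the scipy-style formulation used in Distance.cpp:
        # x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x)).
        w12 = torch.sum(x1 * x2, dim)
        w1 = torch.sum(x1 * x1, dim)
        w2 = torch.sum(x2 * x2, dim)
        n12 = (w1 * w2).clamp_min(eps * eps).sqrt()
        return w12 / n12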