Fix type promotion for cosine_similarity() (#62054)

Summary:
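Fixes https://github.com/pytorch/pytorch/issues/61454

cosine_similarity() now converts both inputs to their common dtype
(at::result_type(x1, x2)) before computing the products, so inputs with
different dtypes (e.g. int8 and float) go through type promotion. Adds a
regression test for the mixed-dtype case and notes type-promotion support
in the function's documentation.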

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62054

Reviewed By: suo

Differential Revision: D29881755

Pulled By: jbschlosser

fbshipit-source-id: 10499766ac07b0ae3c0d2f4c426ea818d1e77db6
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index 843262a..2d3164d 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -237,11 +237,14 @@
 Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
   TORCH_CHECK(x1.sizes() == x2.sizes(), "cosine_similarity requires both inputs to have the same sizes, but x1 has ",
               x1.sizes(), " and x2 has ", x2.sizes())
+  auto commonDtype = at::result_type(x1, x2);
+  Tensor x1_ = x1.to(commonDtype);
+  Tensor x2_ = x2.to(commonDtype);
   // Follow scipy impl to improve numerical precision
   // Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
-  Tensor w12 = at::sum(x1 * x2, dim);
-  Tensor w1 = at::sum(x1 * x1, dim);
-  Tensor w2 = at::sum(x2 * x2, dim);
+  Tensor w12 = at::sum(x1_ * x2_, dim);
+  Tensor w1 = at::sum(x1_ * x1_, dim);
+  Tensor w2 = at::sum(x2_ * x2_, dim);
   Tensor n12 = (w1 * w2).clamp_min_(eps * eps).sqrt_();
   return w12.div_(n12);
 }
diff --git a/test/test_nn.py b/test/test_nn.py
index 3e8d826..962386c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9382,6 +9382,11 @@
         with self.assertRaises(RuntimeError):
             F.cosine_similarity(input1, input2)
 
+        # Check type promotion, issue #61454
+        input = torch.tensor(12.)
+        out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
+        self.assertEqual(out, 1.)
+
     def test_grid_sample_error_checking(self):
         input = torch.empty(1, 1, 2, 2)
         grid = torch.empty(1, 1, 1, 2)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 3621d70..70d69d7 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4243,6 +4243,8 @@
 .. math ::
     \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
 
+Supports :ref:`type promotion <type-promotion-doc>`.
+
 Args:
     x1 (Tensor): First input.
     x2 (Tensor): Second input (of size matching x1).