[numpy] `torch.rsqrt` : promote integer inputs to float (#47909)

Summary:
Reference https://github.com/pytorch/pytorch/issues/42515

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47909

Reviewed By: ngimel

Differential Revision: D25730876

Pulled By: mruberry

fbshipit-source-id: c87a8f686e1dd64e511640e0278021c4a584ccf2
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index e6dd1bc..0f6da7e4 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -326,8 +326,12 @@
 Tensor reciprocal(const Tensor& self) { return unary_op_impl_float(self, reciprocal_stub); }
 Tensor& reciprocal_(Tensor& self) { return unary_op_impl_(self, at::reciprocal_out); }
 
-Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, rsqrt_stub); }
-Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); }
+Tensor& rsqrt_out(Tensor& result, const Tensor& self) {
+  return unary_op_impl_float_out(result, self, rsqrt_stub);
+}
+Tensor rsqrt(const Tensor& self) {
+  return unary_op_impl_float(self, rsqrt_stub);
+}
 Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); }
 
 Tensor& sign_out(Tensor& result, const Tensor& self) {
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 049b3ef..32ebaf7 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -587,10 +587,10 @@
 }
 
 static void rsqrt_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "rsqrt_cpu", [&] {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "rsqrt_cpu", [&] {
     cpu_kernel_vec(
         iter,
-        [=](scalar_t a) -> scalar_t {
+        [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
           return (static_cast<scalar_t>(1)) / std::sqrt(a);
         },
         [=](Vec256<scalar_t> a) { return a.rsqrt(); });
diff --git a/test/test_torch.py b/test/test_torch.py
index 6532c2e..8872516 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6870,7 +6870,6 @@
     ('rot90', 'k1_d12', _small_3d, lambda t, d: [1, [1, 2]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
     ('rot90', 'k1_neg_d', _small_3d, lambda t, d: [1, [1, -1]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
     ('rot90', 'default', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False),
-    ('rsqrt', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-2, 1e-5, 1e-4, _float_types_no_half),
     ('sinh', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types),
     ('tan', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types),
     ('tan', 'complex', lambda t, d: _small_3d(t, d), lambda t, d: [], 1e-3, 1e-5, 1e-5, _complex_types),
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 7764823..1daecc2 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -1715,7 +1715,6 @@
     _TorchMathTestMeta('ceil'),
     _TorchMathTestMeta('rad2deg'),
     _TorchMathTestMeta('deg2rad'),
-    _TorchMathTestMeta('rsqrt', reffn=lambda x: np.reciprocal(np.sqrt(x))),
     _TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)),
     _TorchMathTestMeta('trunc'),
     _TorchMathTestMeta('round'),
diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp
index e60a0bd..186af3c 100644
--- a/torch/csrc/jit/tensorexpr/eval.cpp
+++ b/torch/csrc/jit/tensorexpr/eval.cpp
@@ -834,8 +834,12 @@
         return std::erfc(v);
       case kSqrt:
         return std::sqrt(v);
-      case kRsqrt:
-        return 1.0f / std::sqrt(v);
+      case kRsqrt: {
+        auto rsqrt = [](TInput v) __ubsan_ignore_float_divide_by_zero__ {
+          return 1.0f / std::sqrt(v);
+        };
+        return rsqrt(v);
+      }
       case kCeil:
         return std::ceil(v);
       case kFloor:
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 999186d..0145014 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1282,8 +1282,9 @@
     } break;
 
     case aten::rsqrt: {
-      return computeOneOperand(
-          "aten_rsqrt", v, [](const ExprHandle& a) { return rsqrt(a); });
+      return computeOneOperand("aten_rsqrt", v, [](const ExprHandle& a) {
+        return rsqrt(promoteIntegerToDefaultType(a));
+      });
     } break;
 
     case aten::abs: {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 87d0baa..0c5a5a6 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1001,6 +1001,16 @@
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 dtypes=[torch.bfloat16]),
                    )),
+    UnaryUfuncInfo('rsqrt',
+                   ref=lambda x: np.reciprocal(np.sqrt(x)),
+                   domain=(0, float('inf')),
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
+                   decorators=(precisionOverride({torch.half: 5e-2}),),
+                   promotes_integers_to_float=True,
+                   assert_autodiffed=True,
+                   handles_complex_extremals=False),
     UnaryUfuncInfo('sqrt',
                    ref=np.sqrt,
                    domain=(0, float('inf')),
@@ -1466,6 +1476,10 @@
         ('ceil', (), NO_ARGS, 'scalar', (True,)),
         ('rad2deg', (S, S, S), NO_ARGS),
         ('deg2rad', (S, S, S), NO_ARGS),
+        # Removing the 'rsqrt' entries leads to failure in
+        # test_index_fill_variable_dim_*
+        # TODO: Remove when fixed.
+        # Reference: https://github.com/pytorch/pytorch/issues/48230
         ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
         ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
         ('rsqrt', torch.rand(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),