[numpy] `torch.a{cos, tan}`: promote integer inputs to float (#47005)

Summary:
`torch.acos` and `torch.atan` now promote integer (and bool) inputs to the default floating-point dtype, matching NumPy's type-promotion behavior for `np.arccos` and `np.arctan`.

Reference: https://github.com/pytorch/pytorch/issues/42515
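
For example, a minimal sketch of the new behavior (before this change, integer inputs raised a RuntimeError, since the kernels in the diff below dispatched only on floating-point and complex types):

    import torch

    t = torch.tensor([0, 1])     # inferred dtype: torch.int64
    torch.acos(t).dtype          # torch.float32: promoted, matches np.arccos
    torch.atan(t).dtype          # torch.float32: promoted, matches np.arctan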

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47005

Reviewed By: mrshenli

Differential Revision: D24681097

Pulled By: mruberry

fbshipit-source-id: 2f29655a5f3871ee96c2bfd35c93f4d721730e37
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 4340a80..8c3fc18 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -117,8 +117,8 @@
   return out_impl(self, self);
 }
 
-Tensor& acos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acos_stub); }
-Tensor acos(const Tensor& self) { return unary_op_impl(self, at::acos_out); }
+Tensor& acos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, acos_stub); }
+Tensor acos(const Tensor& self) { return unary_op_impl_float(self, acos_stub); }
 Tensor& acos_(Tensor& self) { return unary_op_impl_(self, at::acos_out); }
 
 // arccos, alias for acos
@@ -158,8 +158,8 @@
 Tensor arcsin(const Tensor& self) { return self.asin(); }
 Tensor& arcsin_(Tensor& self) { return self.asin_(); }
 
-Tensor& atan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atan_stub); }
-Tensor atan(const Tensor& self) { return unary_op_impl(self, at::atan_out); }
+Tensor& atan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, atan_stub); }
+Tensor atan(const Tensor& self) { return unary_op_impl_float(self, atan_stub); }
 Tensor& atan_(Tensor& self) { return unary_op_impl_(self, at::atan_out); }
 
 // arctan, alias of atan
diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
index 773d650..6bf0bdc 100644
--- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -11,7 +11,7 @@
 namespace at { namespace native {
 
 void acos_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "acos_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "acos_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::acos(a);
     });
@@ -27,7 +27,7 @@
 }
 
 void atan_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "atan_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "atan_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::atan(a);
     });
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 4787169..7528cfe 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1091,8 +1091,9 @@
     } break;
 
     case aten::acos: {
-      return computeOneOperand(
-          "aten_acos", v, [](const ExprHandle& a) { return acos(a); });
+      return computeOneOperand("aten_acos", v, [](const ExprHandle& a) {
+        return acos(promoteIntegerToFloat(a));
+      });
     } break;
 
     case aten::asin: {
@@ -1111,8 +1112,9 @@
     } break;
 
     case aten::atan: {
-      return computeOneOperand(
-          "aten_atan", v, [](const ExprHandle& a) { return atan(a); });
+      return computeOneOperand("aten_atan", v, [](const ExprHandle& a) {
+        return atan(promoteIntegerToFloat(a));
+      });
     } break;
 
     case aten::atan2: {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4c87e3a..b588ded 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -211,9 +211,13 @@
                    ref=np.arccos,
                    domain=(-1, 1),
                    handles_complex_extremals=False,
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
                    decorators=(precisionOverride({torch.float16: 1e-2,
                                                   torch.bfloat16: 1e-1,
                                                   torch.complex64: 1e-2}),),
+                   promotes_integers_to_float=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
@@ -257,7 +261,11 @@
                    test_inplace_grad=False),
     UnaryUfuncInfo('atan',
                    ref=np.arctan,
+                   dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+                   promotes_integers_to_float=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),