Let >> and << support half on CUDA (#37670)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37670

Differential Revision: D21395325

Pulled By: ngimel

fbshipit-source-id: fcb02f3bee488717cdc1ffc05204970b907d3c3f
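
For reference, this change enables the bitwise shift operators on half-precision CUDA tensors by extending the existing floating-point emulation (shift as multiply/divide by a power of two) to ScalarType::Half. A minimal usage sketch, assuming a CUDA device is available and using made-up example values:

    import torch

    a = torch.tensor([1.0, 2.0, 4.0], dtype=torch.half, device="cuda")
    b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.half, device="cuda")

    # For floating dtypes, a << b is emulated as a * 2**b and a >> b as a / 2**b.
    print(a << b)   # tensor([ 2.,  8., 32.], ...) in float16
    print(a >> b)   # tensor([0.5000, 0.5000, 0.5000], ...) in float16

Previously the Half case fell through to the else branch of the dispatch below and failed; it now takes the same floating-point path as float and double.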
diff --git a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu
index 791bcc5..404cfdb 100644
--- a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu
@@ -11,12 +11,14 @@
 
 
 void lshift_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) {
-    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "lshift_cuda", [&]() {
+  if (iter.dtype() == ScalarType::Float ||
+      iter.dtype() == ScalarType::Double ||
+      iter.dtype() == ScalarType::Half) {
+    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "lshift_cuda", [&]() {
       gpu_kernel_with_scalars(
         iter,
         []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-          return a * std::pow((scalar_t)(2), b);
+          return a * std::pow(static_cast<scalar_t>(2), b);
       });
     });
   } else {
@@ -30,12 +32,14 @@
 }
 
 void rshift_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) {
-    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "rshift_cuda", [&]() {
+  if (iter.dtype() == ScalarType::Float ||
+      iter.dtype() == ScalarType::Double ||
+      iter.dtype() == ScalarType::Half) {
+    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "rshift_cuda", [&]() {
       gpu_kernel_with_scalars(
         iter,
         []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-          return a / std::pow((scalar_t)(2), b);
+          return a / std::pow(static_cast<scalar_t>(2), b);
       });
     });
   } else {
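
The two kernel lambdas above implement shifts for floating dtypes as scaling by powers of two; the functional change is simply dispatching to ScalarType::Half as well (plus a static_cast cleanup). A rough Python mirror of that semantics, with hypothetical helper names used purely for illustration:

    import torch

    def lshift_float_reference(a, b):
        # Mirrors the CUDA lambda: a * pow(2, b), computed in a's dtype.
        two = torch.tensor(2, dtype=a.dtype, device=a.device)
        return a * torch.pow(two, b)

    def rshift_float_reference(a, b):
        # Mirrors the CUDA lambda: a / pow(2, b), computed in a's dtype.
        two = torch.tensor(2, dtype=a.dtype, device=a.device)
        return a / torch.pow(two, b)

    x = torch.arange(1, 5).to(dtype=torch.half, device="cuda")
    s = torch.tensor(2.0, dtype=torch.half, device="cuda")
    # Compare in float32 to sidestep half-precision tolerance quirks.
    assert torch.allclose((x << s).float(), lshift_float_reference(x, s).float())
    assert torch.allclose((x >> s).float(), rshift_float_reference(x, s).float())
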
diff --git a/test/test_torch.py b/test/test_torch.py
index 26ef448..875a982 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -17575,11 +17575,11 @@
     ('__lshift__', '',
         lambda t, d: torch.pow(2, torch.arange(1, 5).to(dtype=_convert_t(t, d), device=d)),
         lambda t, d: [2],
-        1e-3, 1e-5, 1e-3, _signed_types_no_half, _cpu_types, False),
+        1e-3, 1e-5, 1e-3, _signed_types, _cpu_types, False),
     ('__rshift__', '',
         lambda t, d: torch.pow(2, torch.arange(3, 7).to(dtype=_convert_t(t, d), device=d)),
         lambda t, d: [2],
-        1e-3, 1e-5, 1e-3, _signed_types_no_half, _cpu_types, False),
+        1e-3, 1e-5, 1e-3, _signed_types, _cpu_types, False),
     # lapack tests
     ('qr', 'square', _small_2d, lambda t, d: [],
         1e-5, 1e-5, 3e-4, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]),
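
On the test side, the __lshift__/__rshift__ entries above now use _signed_types instead of _signed_types_no_half, so the half dtype is exercised on CUDA with the same operands as before. A rough standalone equivalent of the newly covered half case (not how the harness actually invokes it):

    import torch

    lhs = torch.pow(2, torch.arange(1, 5)).to(dtype=torch.half, device="cuda")
    print(lhs << 2)   # each element multiplied by 2**2 -> tensor([ 8., 16., 32., 64.], ...)

    rhs = torch.pow(2, torch.arange(3, 7)).to(dtype=torch.half, device="cuda")
    print(rhs >> 2)   # each element divided by 2**2 -> tensor([ 2.,  4.,  8., 16.], ...)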