Let >> and << support half on CUDA (#37670)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37670
Differential Revision: D21395325
Pulled By: ngimel
fbshipit-source-id: fcb02f3bee488717cdc1ffc05204970b907d3c3f
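Half-precision inputs previously fell through to the integral dispatch branch of these kernels, which does not handle `Half`; the change below adds `ScalarType::Half` to the floating-point path. A minimal usage sketch of the new behavior, assuming a CUDA device is available (tensor values are illustrative):

```python
import torch

# Half-precision CUDA tensors now work with the shift operators.
a = torch.tensor([1.0, 2.0, 4.0], dtype=torch.half, device="cuda")
print(a << 2)  # a * 2**2 -> [4., 8., 16.], still torch.float16
print(a >> 1)  # a / 2**1 -> [0.5, 1., 2.], still torch.float16
```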
diff --git a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu
index 791bcc5..404cfdb 100644
--- a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu
@@ -11,12 +11,14 @@
 void lshift_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) {
-    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "lshift_cuda", [&]() {
+  if (iter.dtype() == ScalarType::Float ||
+      iter.dtype() == ScalarType::Double ||
+      iter.dtype() == ScalarType::Half) {
+    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "lshift_cuda", [&]() {
       gpu_kernel_with_scalars(
         iter,
         []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-          return a * std::pow((scalar_t)(2), b);
+          return a * std::pow(static_cast<scalar_t>(2), b);
       });
     });
   } else {
@@ -30,12 +32,14 @@
 }
 
 void rshift_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double) {
-    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "rshift_cuda", [&]() {
+  if (iter.dtype() == ScalarType::Float ||
+      iter.dtype() == ScalarType::Double ||
+      iter.dtype() == ScalarType::Half) {
+    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "rshift_cuda", [&]() {
       gpu_kernel_with_scalars(
         iter,
         []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-          return a / std::pow((scalar_t)(2), b);
+          return a / std::pow(static_cast<scalar_t>(2), b);
       });
     });
   } else {
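As the lambdas above show, the floating-point path implements the shifts arithmetically: `a << b` computes `a * 2**b` and `a >> b` computes `a / 2**b`. A quick consistency check of that identity for half tensors, assuming a CUDA device (values are arbitrary but exactly representable in float16):

```python
import torch

a = torch.tensor([1.0, 3.0, 5.0], dtype=torch.half, device="cuda")
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.half, device="cuda")

# Reference computed in float32, then cast back to half for comparison.
assert torch.allclose(a << b, (a.float() * torch.pow(2.0, b.float())).half())
assert torch.allclose(a >> b, (a.float() / torch.pow(2.0, b.float())).half())
```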
diff --git a/test/test_torch.py b/test/test_torch.py
index 26ef448..875a982 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -17575,11 +17575,11 @@
     ('__lshift__', '',
         lambda t, d: torch.pow(2, torch.arange(1, 5).to(dtype=_convert_t(t, d), device=d)),
         lambda t, d: [2],
-        1e-3, 1e-5, 1e-3, _signed_types_no_half, _cpu_types, False),
+        1e-3, 1e-5, 1e-3, _signed_types, _cpu_types, False),
     ('__rshift__', '',
         lambda t, d: torch.pow(2, torch.arange(3, 7).to(dtype=_convert_t(t, d), device=d)),
         lambda t, d: [2],
-        1e-3, 1e-5, 1e-3, _signed_types_no_half, _cpu_types, False),
+        1e-3, 1e-5, 1e-3, _signed_types, _cpu_types, False),
     # lapack tests
     ('qr', 'square', _small_2d, lambda t, d: [],
         1e-5, 1e-5, 3e-4, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]),
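The test change swaps `_signed_types_no_half` for `_signed_types`, so the `__lshift__`/`__rshift__` entries in the tensor op test table now cover half as well. Roughly what those entries exercise, reusing the table's constructors (a sketch assuming a CUDA device):

```python
import torch

# __lshift__ entry: 2 ** [1, 2, 3, 4] as half, shifted left by 2.
t = torch.pow(2, torch.arange(1, 5)).to(dtype=torch.half, device="cuda")
print(t << 2)  # [8., 16., 32., 64.]

# __rshift__ entry: 2 ** [3, 4, 5, 6] as half, shifted right by 2.
u = torch.pow(2, torch.arange(3, 7)).to(dtype=torch.half, device="cuda")
print(u >> 2)  # [2., 4., 8., 16.]
```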