CUDA BFloat16 pow (#44760)

Summary: Enables BFloat16 support for `pow` on CUDA by removing the `AT_SKIP_BFLOAT16_IF_NOT_ROCM` guard in `PowKernel.cu` and extending the pow test entries from `_float_types2` to `torch.testing.get_all_fp_dtypes()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44760

Reviewed By: ngimel

Differential Revision: D23727936

Pulled By: mruberry

fbshipit-source-id: 8aa89e989294347d7f593b1a63ce4a1dbfdf783e
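
Previously, `AT_SKIP_BFLOAT16_IF_NOT_ROCM` made these kernels a no-op for bfloat16 on CUDA (non-ROCm) builds; with the guard gone, `torch.pow` works for bfloat16 there too. A minimal usage sketch, assuming a CUDA-capable build:

```python
import torch

base = torch.rand(4, device="cuda", dtype=torch.bfloat16)

print(torch.pow(base, 2))     # tensor ** Python scalar
print(torch.pow(base, base))  # tensor ** tensor
print(base ** 0.5)            # operator form dispatches to the same kernels
```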
diff --git a/aten/src/ATen/native/cuda/PowKernel.cu b/aten/src/ATen/native/cuda/PowKernel.cu
index efd196c..3926b35 100644
--- a/aten/src/ATen/native/cuda/PowKernel.cu
+++ b/aten/src/ATen/native/cuda/PowKernel.cu
@@ -110,10 +110,8 @@
     });
   } else if (isFloatingType(iter.dtype())) {
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "pow_cuda", [&]() {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "pow_cuda", [&] {
-        gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
-          return pow_(base, exp);
-        });
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
+        return pow_(base, exp);
       });
     });
   } else {
@@ -170,10 +168,8 @@
     });
   } else if (isFloatingType(iter.dtype()) || exp_scalar.isIntegral(false)) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "pow_cuda", [&]() {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "pow_cuda", [&] {
-        const auto exp = exp_scalar.to<scalar_t>();
-        pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
-      });
+      const auto exp = exp_scalar.to<scalar_t>();
+      pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
     });
   } else {
     const auto exp = exp_scalar.to<float>();
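
Both code paths touched above, the tensor-tensor kernel (first hunk) and the tensor-scalar kernel (second hunk), can be sanity-checked from Python; a minimal sketch, assuming a CUDA device is available, with tolerances mirroring the 1e-1 used in the tests below:

```python
import torch

base = torch.rand(1024, device="cuda", dtype=torch.bfloat16) + 0.5
exp = torch.rand(1024, device="cuda", dtype=torch.bfloat16)

# Tensor-tensor path (first hunk), compared against a float32 reference.
got = torch.pow(base, exp)
ref = torch.pow(base.float(), exp.float())
assert torch.allclose(got.float(), ref, rtol=1e-1, atol=1e-1)

# Tensor-scalar path (second hunk).
got_s = torch.pow(base, 3.14)
ref_s = torch.pow(base.float(), 3.14)
assert torch.allclose(got_s.float(), ref_s, rtol=1e-1, atol=1e-1)
```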
diff --git a/test/test_torch.py b/test/test_torch.py
index 860f1b3..9819ac4 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -19800,15 +19800,15 @@
     ('floor_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1, 1e-5, 1e-5, _types),
     ('floor_divide', 'tensor', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1, 1e-5, 1e-5, _types),
-    ('pow', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1, 1e-1, 1e-5, _float_types2),
-    ('pow', '1', _small_3d, lambda t, d: [_number(1., 1, t)], 1e-1, 1e-1, 1e-5, _float_types2),
-    ('pow', '2', _small_3d, lambda t, d: [_number(2., 2, t)], 1e-1, 1e-1, 1e-5, _float_types2),
-    ('pow', '3', _small_3d, lambda t, d: [_number(3., 3, t)], 1e-1, 1e-1, 1e-5, _float_types2),
-    ('pow', '-1', _small_3d, lambda t, d: [_number(-1., -1, t)], 1e-1, 1e-1, 1e-5, _float_types2),
+    ('pow', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+    ('pow', '1', _small_3d, lambda t, d: [_number(1., 1, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+    ('pow', '2', _small_3d, lambda t, d: [_number(2., 2, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+    ('pow', '3', _small_3d, lambda t, d: [_number(3., 3, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+    ('pow', '-1', _small_3d, lambda t, d: [_number(-1., -1, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('pow', '-2', _small_3d, lambda t, d: [_number(-2., -2, t)],
         1e-1, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False),
     ('pow', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d).abs()],
-        1e-1, 1e-1, 1e-5, _float_types2),
+        1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, _complex_and_float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
     ('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
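
The test change swaps the fixed `_float_types2` list for `torch.testing.get_all_fp_dtypes()`, which includes bfloat16, so the pow entries above now exercise the newly enabled dtype. A quick sketch of what the new dtype set contains (exact ordering is an assumption):

```python
import torch

# get_all_fp_dtypes() covers all floating-point dtypes, bfloat16 included.
fp_dtypes = torch.testing.get_all_fp_dtypes()
assert torch.bfloat16 in fp_dtypes
assert torch.half in fp_dtypes
```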