CUDA BFloat16 pow (#44760)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44760
Reviewed By: ngimel
Differential Revision: D23727936
Pulled By: mruberry
fbshipit-source-id: 8aa89e989294347d7f593b1a63ce4a1dbfdf783e
diff --git a/aten/src/ATen/native/cuda/PowKernel.cu b/aten/src/ATen/native/cuda/PowKernel.cu
index efd196c..3926b35 100644
--- a/aten/src/ATen/native/cuda/PowKernel.cu
+++ b/aten/src/ATen/native/cuda/PowKernel.cu
@@ -110,10 +110,8 @@
});
} else if (isFloatingType(iter.dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "pow_cuda", [&]() {
- AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "pow_cuda", [&] {
- gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
- return pow_(base, exp);
- });
+ gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
+ return pow_(base, exp);
});
});
} else {
@@ -170,10 +168,8 @@
});
} else if (isFloatingType(iter.dtype()) || exp_scalar.isIntegral(false)) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "pow_cuda", [&]() {
- AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "pow_cuda", [&] {
- const auto exp = exp_scalar.to<scalar_t>();
- pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
- });
+ const auto exp = exp_scalar.to<scalar_t>();
+ pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
});
} else {
const auto exp = exp_scalar.to<float>();
diff --git a/test/test_torch.py b/test/test_torch.py
index 860f1b3..9819ac4 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -19800,15 +19800,15 @@
('floor_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1, 1e-5, 1e-5, _types),
('floor_divide', 'tensor', _small_3d,
lambda t, d: [_small_3d(t, d, has_zeros=False)], 1, 1e-5, 1e-5, _types),
- ('pow', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1, 1e-1, 1e-5, _float_types2),
- ('pow', '1', _small_3d, lambda t, d: [_number(1., 1, t)], 1e-1, 1e-1, 1e-5, _float_types2),
- ('pow', '2', _small_3d, lambda t, d: [_number(2., 2, t)], 1e-1, 1e-1, 1e-5, _float_types2),
- ('pow', '3', _small_3d, lambda t, d: [_number(3., 3, t)], 1e-1, 1e-1, 1e-5, _float_types2),
- ('pow', '-1', _small_3d, lambda t, d: [_number(-1., -1, t)], 1e-1, 1e-1, 1e-5, _float_types2),
+ ('pow', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+ ('pow', '1', _small_3d, lambda t, d: [_number(1., 1, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+ ('pow', '2', _small_3d, lambda t, d: [_number(2., 2, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+ ('pow', '3', _small_3d, lambda t, d: [_number(3., 3, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
+ ('pow', '-1', _small_3d, lambda t, d: [_number(-1., -1, t)], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
('pow', '-2', _small_3d, lambda t, d: [_number(-2., -2, t)],
1e-1, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False),
('pow', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d).abs()],
- 1e-1, 1e-1, 1e-5, _float_types2),
+ 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
1e-1, 1e-1, 1e-4, _complex_and_float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],