[bfloat16][easy] kthvalue, median (#117279)

Extends BFloat16 support to kthvalue, median, and nanmedian on CUDA by switching the kernel launchers to AT_DISPATCH_ALL_TYPES_AND2, and removes the now-redundant dtypesIfCUDA overrides from the corresponding OpInfo entries (including masked.median) so the bfloat16 variants are exercised on CUDA as well.

Fixes #109991
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117279
Approved by: https://github.com/Skylion007
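
For context, a minimal sketch of the behavior this change enables (assuming a CUDA-capable build; the shapes and the k value are illustrative, not taken from the PR):

```python
import torch

# After this change, kthvalue/median/nanmedian dispatch for bfloat16 on CUDA
# instead of failing dtype dispatch ("not implemented for 'BFloat16'").
x = torch.randn(4, 8, device="cuda").to(torch.bfloat16)

values, indices = torch.kthvalue(x, 3, dim=1)  # per-row 3rd smallest value
med_vals, med_idx = torch.median(x, dim=1)     # per-row median with its index

# Median only selects an element (no arithmetic), so the bfloat16 result should
# match a float32 reference exactly once both are compared in float32.
ref_vals, _ = torch.median(x.float(), dim=1)
assert torch.equal(med_vals.float(), ref_vals)

# nanmedian goes through the same launcher with ignore_nan=True, so it is covered too.
x[0, 0] = float("nan")
nan_med = torch.nanmedian(x, dim=1).values
```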
diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu
index 313c6d1..385fc97 100644
--- a/aten/src/ATen/native/cuda/Sorting.cu
+++ b/aten/src/ATen/native/cuda/Sorting.cu
@@ -247,8 +247,8 @@
void launch_kthvalue_kernel(
const TensorBase &values, const TensorBase &indices,
const TensorBase &self, int64_t dim, int64_t k) {
- AT_DISPATCH_ALL_TYPES_AND(
- at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(
+ at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "kthvalue_cuda", [&] {
AT_DISPATCH_INDEX_TYPES(
cuda::detail::canUse32BitIndexMath(self) &&
cuda::detail::canUse32BitIndexMath(values) &&
@@ -263,8 +263,8 @@
void launch_median_kernel(
const TensorBase &vals, const TensorBase &inds,
const TensorBase &self, int64_t dim, bool ignore_nan) {
- AT_DISPATCH_ALL_TYPES_AND(
- at::ScalarType::Half, self.scalar_type(), "median_out_impl", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(
+ at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "median_out_impl", [&] {
if (cuda::detail::canUse32BitIndexMath(vals) &&
cuda::detail::canUse32BitIndexMath(inds) &&
cuda::detail::canUse32BitIndexMath(self)) {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f029151..b4fb1ef 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11911,7 +11911,6 @@
sample_inputs_func=sample_inputs_isin),
OpInfo('kthvalue',
dtypes=all_types_and(torch.bfloat16, torch.float16),
- dtypesIfCUDA=all_types_and(torch.float16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_kthvalue,
@@ -12336,7 +12335,6 @@
)),
OpInfo('median',
dtypes=all_types_and(torch.bfloat16, torch.float16),
- dtypesIfCUDA=all_types_and(torch.float16),
# TODO: some signatures of median do support out
supports_out=False,
supports_forward_ad=True,
@@ -12345,7 +12343,6 @@
sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
OpInfo('nanmedian',
dtypes=all_types_and(torch.bfloat16, torch.float16),
- dtypesIfCUDA=all_types_and(torch.float16),
# TODO: some signatures of nanmedian do support out
supports_out=False,
supports_forward_ad=True,
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index 11298dd..b99368d 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -834,7 +834,6 @@
OpInfo(
"masked.median",
dtypes=floating_types_and(torch.bfloat16, torch.float16),
- dtypesIfCUDA=floating_types_and(torch.float16),
method_variant=None,
supports_out=False,
supports_forward_ad=True,
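
The OpInfo edits rely on dtypesIfCUDA falling back to `dtypes` when it is left unspecified, so deleting the override is enough for the CUDA test variants to pick up bfloat16. A rough illustration of that fallback, assuming a PyTorch dev install where the internal `op_db` list and `OpInfo.supported_dtypes` helper are importable:

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

# Illustrative only: with dtypesIfCUDA removed, each OpInfo inherits its CUDA
# dtypes from `dtypes`, which now lists bfloat16 for these ops.
for name in ("kthvalue", "median", "nanmedian", "masked.median"):
    op = next(op for op in op_db if op.name == name)
    assert torch.bfloat16 in op.supported_dtypes("cuda"), name
```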