[bfloat16][easy] kthvalue, median (#117279)

Fixes #109991
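
This widens the CUDA dispatch for the sorting-based reductions from
`AT_DISPATCH_ALL_TYPES_AND(Half, ...)` to
`AT_DISPATCH_ALL_TYPES_AND2(Half, BFloat16, ...)`, so the kernel lambdas in
`launch_kthvalue_kernel` and `launch_median_kernel` are also instantiated for
`at::BFloat16`. With the kernels covering bfloat16, the `dtypesIfCUDA`
overrides in the OpInfos for `kthvalue`, `median`, `nanmedian`, and
`masked.median` are no longer needed: the CUDA dtype list now matches the CPU
one, and the OpInfo-based tests exercise bfloat16 on CUDA as well.

A minimal usage sketch (not part of this patch; assumes a CUDA-capable build)
of what newly works:

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(4, 5, device="cuda").to(torch.bfloat16)

    # Previously: RuntimeError: "kthvalue_cuda" not implemented for 'BFloat16'
    vals, idxs = torch.kthvalue(x, k=2, dim=1)  # 2nd-smallest along dim 1

    med = torch.median(x, dim=1)                # (values, indices) namedtuple
    nan_med = torch.nanmedian(x, dim=1).values  # NaN-ignoring variant
```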
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117279
Approved by: https://github.com/Skylion007
diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu
index 313c6d1..385fc97 100644
--- a/aten/src/ATen/native/cuda/Sorting.cu
+++ b/aten/src/ATen/native/cuda/Sorting.cu
@@ -247,8 +247,8 @@
 void launch_kthvalue_kernel(
     const TensorBase &values, const TensorBase &indices,
     const TensorBase &self, int64_t dim, int64_t k) {
-  AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "kthvalue_cuda", [&] {
     AT_DISPATCH_INDEX_TYPES(
         cuda::detail::canUse32BitIndexMath(self) &&
         cuda::detail::canUse32BitIndexMath(values) &&
@@ -263,8 +263,8 @@
 void launch_median_kernel(
     const TensorBase &vals, const TensorBase &inds,
     const TensorBase &self, int64_t dim, bool ignore_nan) {
-  AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half, self.scalar_type(), "median_out_impl", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "median_out_impl", [&] {
         if (cuda::detail::canUse32BitIndexMath(vals) &&
             cuda::detail::canUse32BitIndexMath(inds) &&
             cuda::detail::canUse32BitIndexMath(self)) {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f029151..b4fb1ef 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11911,7 +11911,6 @@
            sample_inputs_func=sample_inputs_isin),
     OpInfo('kthvalue',
            dtypes=all_types_and(torch.bfloat16, torch.float16),
-           dtypesIfCUDA=all_types_and(torch.float16),
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_kthvalue,
@@ -12336,7 +12335,6 @@
            )),
     OpInfo('median',
            dtypes=all_types_and(torch.bfloat16, torch.float16),
-           dtypesIfCUDA=all_types_and(torch.float16),
            # TODO: some signatures of median do support out
            supports_out=False,
            supports_forward_ad=True,
@@ -12345,7 +12343,6 @@
            sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
     OpInfo('nanmedian',
            dtypes=all_types_and(torch.bfloat16, torch.float16),
-           dtypesIfCUDA=all_types_and(torch.float16),
            # TODO: some signatures of nanmedian do support out
            supports_out=False,
            supports_forward_ad=True,
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index 11298dd..b99368d 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -834,7 +834,6 @@
     OpInfo(
         "masked.median",
         dtypes=floating_types_and(torch.bfloat16, torch.float16),
-        dtypesIfCUDA=floating_types_and(torch.float16),
         method_variant=None,
         supports_out=False,
         supports_forward_ad=True,