[BC] Allow only `bool` tensors as mask in `masked_select` (#96112)

Support for `byte` (`torch.uint8`) masks was marked as deprecated in 1.8, so it is fine to remove it in 2.1 (or even 2.0).
The deprecation warning was added by https://github.com/pytorch/pytorch/pull/22261

Also, fix a bunch of spelling errors in comments

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96112
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 24ea406..aaa5f8f 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -5,7 +5,7 @@
 //  index(Tensor self, indices) -> Tensor
 //  index_put_(Tensor self, indices, value, accumulate=false)
 //
-// The index is a TensorList containg kLong, kBool or kByte tensors or nulls. Byte
+// The index is a TensorList containing kLong, kBool or kByte tensors or nulls. Byte
 // tensors (boolean masks) are expanded to long tensors via nonzero(). Null
 // tensors signify that the dimension is not indexed.
 //
@@ -1842,8 +1842,8 @@
 static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
   NoNamesGuard guard;
 
-  TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
-              "masked_select: expected BoolTensor or ByteTensor for mask");
+  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
+              "masked_select: expected BoolTensor for mask");
   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
               "masked_select(): self and result must have the same scalar type");
 
@@ -1851,11 +1851,6 @@
   at::assert_no_overlap(result, self);
   at::assert_no_overlap(result, mask);
 
-  if (mask.dtype() == at::ScalarType::Byte) {
-    TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
-            "please use a mask with dtype torch.bool instead.");
-  }
-
   c10::MaybeOwned<Tensor> _mask, _self;
   std::tie(_mask, _self) = expand_outplace(mask, self);
 
@@ -1880,7 +1875,7 @@
   _self->is_contiguous() && _mask->is_contiguous();
   if (use_serial_kernel) {
     auto iter = TensorIteratorConfig()
-      .set_check_mem_overlap(false)  // result is intenionally zero-strided above
+      .set_check_mem_overlap(false)  // result is intentionally zero-strided above
       .check_all_same_dtype(false)
       .resize_outputs(false)
       .add_output(result_strided)
@@ -1899,12 +1894,12 @@
   auto mask_long_data = mask_long.data_ptr<int64_t>();
   auto mask_prefix_sum_data = mask_prefix_sum.data_ptr<int64_t>();
   // TODO: Here can only use std::partial_sum for C++14,
-  // use std::exclusive_scan when PyTorch upgrades to C++17, which have better peformance.
+  // use std::exclusive_scan when PyTorch upgrades to C++17, which have better performance.
   // std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0);
   std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);
 
   auto iter = TensorIteratorConfig()
-    .set_check_mem_overlap(false)  // result is intenionally zero-strided above
+    .set_check_mem_overlap(false)  // result is intentionally zero-strided above
     .check_all_same_dtype(false)
     .resize_outputs(false)
     .add_output(result_strided)
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cpp b/aten/src/ATen/native/cuda/IndexKernel.cpp
index b5e92a1..b0337b9 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cpp
+++ b/aten/src/ATen/native/cuda/IndexKernel.cpp
@@ -24,8 +24,8 @@
 static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
   NoNamesGuard guard;
 
-  TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
-              "masked_select: expected BoolTensor or ByteTensor for mask");
+  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
+              "masked_select: expected BoolTensor for mask");
   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
               "masked_select(): self and result must have the same scalar type");
 
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index 0af63e1a..74d6bde 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -189,8 +189,8 @@
 static Tensor & masked_select_out_mps_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
   NoNamesGuard guard;
 
-  TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
-              "masked_select: expected BoolTensor or ByteTensor for mask");
+  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
+              "masked_select: expected BoolTensor for mask");
   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
               "masked_select(): self and result must have the same scalar type");
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 12dea3b..076f764 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -55,7 +55,7 @@
     tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN)
 from torch.testing._internal.common_dtype import (
     floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
-    all_types_and, floating_types, floating_and_complex_types,
+    all_types_and, floating_types, floating_and_complex_types, integral_types_and
 )
 
 # Protects against includes accidentally setting the default dtype
@@ -629,12 +629,6 @@
         zero_d_uint8 = torch.tensor(1, dtype=torch.uint8, device=device)
         one_d_uint8 = torch.tensor([1], dtype=torch.uint8, device=device)
 
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            self.assertEqual((1,), torch.masked_select(zero_d_uint8, zero_d_uint8).shape)
-            self.assertEqual((1,), torch.masked_select(zero_d_uint8, one_d_uint8).shape)
-            self.assertEqual((1,), torch.masked_select(one_d_uint8, zero_d_uint8).shape)
-
         # mode
         self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=True)])
         self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=False)])
@@ -3688,16 +3682,17 @@
             warn = 'masked_select received a mask with dtype torch.uint8,'
         else:
             warn = 'indexing with dtype torch.uint8 is now deprecated, pl'
-        for maskType in [torch.uint8, torch.bool]:
+        for maskType in integral_types_and(torch.bool):
             num_src = 10
             src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device)
             mask = torch.randint(2, (num_src,), device=device, dtype=maskType)
 
-            with warnings.catch_warnings(record=True) as w:
+            if maskType is not torch.bool:
+                with self.assertRaisesRegex(RuntimeError, r'expected BoolTensor for mask'):
+                    dst = src.masked_select(mask)
+                continue
+            else:
                 dst = src.masked_select(mask)
-                if maskType is torch.uint8:
-                    self.assertEqual(len(w), 1)
-                    self.assertEqual(str(w[0].message)[0:53], str(warn))
             dst2 = []
             for i in range(num_src):
                 if mask[i]: