[BC] Allow only `bool` tensors as mask in `masked_select` (#96112)
`byte` support was marked as deprecated in 1.8, so it's fine to remove this in 2.1 (or even 2.0)
Deprecation warning was added by https://github.com/pytorch/pytorch/pull/22261
Also, fix bunch of syntactic errors in comments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96112
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 24ea406..aaa5f8f 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -5,7 +5,7 @@
// index(Tensor self, indices) -> Tensor
// index_put_(Tensor self, indices, value, accumulate=false)
//
-// The index is a TensorList containg kLong, kBool or kByte tensors or nulls. Byte
+// The index is a TensorList containing kLong, kBool or kByte tensors or nulls. Byte
// tensors (boolean masks) are expanded to long tensors via nonzero(). Null
// tensors signify that the dimension is not indexed.
//
@@ -1842,8 +1842,8 @@
static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
NoNamesGuard guard;
- TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
- "masked_select: expected BoolTensor or ByteTensor for mask");
+ TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
+ "masked_select: expected BoolTensor for mask");
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"masked_select(): self and result must have the same scalar type");
@@ -1851,11 +1851,6 @@
at::assert_no_overlap(result, self);
at::assert_no_overlap(result, mask);
- if (mask.dtype() == at::ScalarType::Byte) {
- TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
- "please use a mask with dtype torch.bool instead.");
- }
-
c10::MaybeOwned<Tensor> _mask, _self;
std::tie(_mask, _self) = expand_outplace(mask, self);
@@ -1880,7 +1875,7 @@
_self->is_contiguous() && _mask->is_contiguous();
if (use_serial_kernel) {
auto iter = TensorIteratorConfig()
- .set_check_mem_overlap(false) // result is intenionally zero-strided above
+ .set_check_mem_overlap(false) // result is intentionally zero-strided above
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(result_strided)
@@ -1899,12 +1894,12 @@
auto mask_long_data = mask_long.data_ptr<int64_t>();
auto mask_prefix_sum_data = mask_prefix_sum.data_ptr<int64_t>();
// TODO: Here can only use std::partial_sum for C++14,
- // use std::exclusive_scan when PyTorch upgrades to C++17, which have better peformance.
+ // use std::exclusive_scan when PyTorch upgrades to C++17, which have better performance.
// std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0);
std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);
auto iter = TensorIteratorConfig()
- .set_check_mem_overlap(false) // result is intenionally zero-strided above
+ .set_check_mem_overlap(false) // result is intentionally zero-strided above
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(result_strided)
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cpp b/aten/src/ATen/native/cuda/IndexKernel.cpp
index b5e92a1..b0337b9 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cpp
+++ b/aten/src/ATen/native/cuda/IndexKernel.cpp
@@ -24,8 +24,8 @@
static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
NoNamesGuard guard;
- TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
- "masked_select: expected BoolTensor or ByteTensor for mask");
+ TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
+ "masked_select: expected BoolTensor for mask");
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"masked_select(): self and result must have the same scalar type");
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index 0af63e1a..74d6bde 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -189,8 +189,8 @@
static Tensor & masked_select_out_mps_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
NoNamesGuard guard;
- TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
- "masked_select: expected BoolTensor or ByteTensor for mask");
+ TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
+ "masked_select: expected BoolTensor for mask");
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"masked_select(): self and result must have the same scalar type");
diff --git a/test/test_torch.py b/test/test_torch.py
index 12dea3b..076f764 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -55,7 +55,7 @@
tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN)
from torch.testing._internal.common_dtype import (
floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
- all_types_and, floating_types, floating_and_complex_types,
+ all_types_and, floating_types, floating_and_complex_types, integral_types_and
)
# Protects against includes accidentally setting the default dtype
@@ -629,12 +629,6 @@
zero_d_uint8 = torch.tensor(1, dtype=torch.uint8, device=device)
one_d_uint8 = torch.tensor([1], dtype=torch.uint8, device=device)
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- self.assertEqual((1,), torch.masked_select(zero_d_uint8, zero_d_uint8).shape)
- self.assertEqual((1,), torch.masked_select(zero_d_uint8, one_d_uint8).shape)
- self.assertEqual((1,), torch.masked_select(one_d_uint8, zero_d_uint8).shape)
-
# mode
self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=True)])
self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=False)])
@@ -3688,16 +3682,17 @@
warn = 'masked_select received a mask with dtype torch.uint8,'
else:
warn = 'indexing with dtype torch.uint8 is now deprecated, pl'
- for maskType in [torch.uint8, torch.bool]:
+ for maskType in integral_types_and(torch.bool):
num_src = 10
src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device)
mask = torch.randint(2, (num_src,), device=device, dtype=maskType)
- with warnings.catch_warnings(record=True) as w:
+ if maskType is not torch.bool:
+ with self.assertRaisesRegex(RuntimeError, r'expected BoolTensor for mask'):
+ dst = src.masked_select(mask)
+ continue
+ else:
dst = src.masked_select(mask)
- if maskType is torch.uint8:
- self.assertEqual(len(w), 1)
- self.assertEqual(str(w[0].message)[0:53], str(warn))
dst2 = []
for i in range(num_src):
if mask[i]: