[chalf] index_select: cpu support (#79217)
Fixes https://github.com/pytorch/pytorch/issues/79204
PR https://github.com/pytorch/pytorch/pull/78173 took care of adding CUDA support.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79217
Approved by: https://github.com/mruberry
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 29a4dea..7c20844 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -1174,7 +1174,7 @@
});
});
} else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
self.scalar_type(), "index_select", [&index_contig, &self, &result, &dim, &numel] {
auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index d1aa20c..6ad92a6 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2469,11 +2469,9 @@
yield from generate_elementwise_binary_large_value_tensors(
op, device=device, dtype=dtype, requires_grad=requires_grad
)
- # TODO: FIXME: RuntimeError: "index_select" not implemented for 'ComplexHalf'
- if dtype not in (torch.chalf,):
- yield from generate_elementwise_binary_broadcasting_tensors(
- op, device=device, dtype=dtype, requires_grad=requires_grad, exclude_zero=exclude_zero
- )
+ yield from generate_elementwise_binary_broadcasting_tensors(
+ op, device=device, dtype=dtype, requires_grad=requires_grad, exclude_zero=exclude_zero
+ )
yield from generate_elementwise_binary_with_scalar_samples(
op, device=device, dtype=dtype, requires_grad=requires_grad
)
@@ -2501,10 +2499,6 @@
# yields "normal" samples
yield from gen()
- # TODO: RuntimeError: "index_select" not implemented for 'ComplexHalf'
- if dtype is torch.chalf:
- return
-
# yields noncontiguous samples
for sample in gen():
yield sample.noncontiguous()
@@ -16520,7 +16514,8 @@
sample_inputs_func=sample_inputs_index,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
OpInfo('index_select',
- dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+ backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_index,
error_inputs_func=error_inputs_index_select,
supports_forward_ad=True,
@@ -16776,9 +16771,6 @@
DecorateInfo(unittest.expectedFailure, 'TestGradients'),
# use of lambda doesn't work with test_normalize_operator_exhaustive
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
- # RuntimeError: "index_select" not implemented for 'ComplexHalf'
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples',
- dtypes=(torch.float, torch.cfloat), device_type='cpu'),
# RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager',
device_type='cpu'),
@@ -19228,7 +19220,8 @@
),
OpInfo(
"repeat_interleave",
- dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+ backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_repeat_interleave,
supports_out=False,
supports_forward_ad=True,