OpInfo: select (#57731)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/54261
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57731
Reviewed By: bdhirsh
Differential Revision: D28318229
Pulled By: mruberry
fbshipit-source-id: ec9058fd188b82de80d3a2f1a1ba07f36d8d0741
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 20d76dc..53875a4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3353,6 +3353,23 @@
return list(generator())
+def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
+ make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+ cases = (((S, S, S), (1, 2)),
+ ((S, S, S), (-1, 2)),
+ ((S, S, S), (-1, -1)),
+ ((S, S, S), (1, -1)),
+ ((S,), (0, 2))
+ )
+
+ def generator():
+ for shape, args in cases:
+ yield SampleInput(make_arg(shape), args=args)
+
+ return list(generator())
+
+
def sample_inputs_rbinops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs):
def _make_tensor_helper(shape, low=None, high=None):
return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
@@ -5094,6 +5111,10 @@
SkipInfo('TestCommon', 'test_variant_consistency_jit',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half)),),
assert_autodiffed=True,),
+ OpInfo('select',
+ dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+ sample_inputs_func=sample_inputs_select,
+ supports_out=False),
UnaryUfuncInfo('signbit',
ref=np.signbit,
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
@@ -6211,9 +6232,6 @@
('fill_', (S, S, S), (1,), 'number'),
('fill_', (), (1,), 'number_scalar'),
('fill_', (S, S, S), ((),), 'variable'),
- ('select', (S, S, S), (1, 2), 'dim', (), [0]),
- ('select', (S, S, S), (1, -1), 'wrap_dim', (), [0]),
- ('select', (S,), (0, 2), '1d'),
('narrow', (S, S, S), (1, 2, 2), 'dim', (), [0]),
('narrow', (S, S, S), (1, 0, 0), 'empty_dim', (), [0]),
('squeeze', (S, 1, S, 1), NO_ARGS, '', (True,)),