OpInfo: select (#57731)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/54261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57731

Reviewed By: bdhirsh

Differential Revision: D28318229

Pulled By: mruberry

fbshipit-source-id: ec9058fd188b82de80d3a2f1a1ba07f36d8d0741
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 20d76dc..53875a4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3353,6 +3353,23 @@
     return list(generator())
 
 
+def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (((S, S, S), (1, 2)),
+             ((S, S, S), (-1, 2)),
+             ((S, S, S), (-1, -1)),
+             ((S, S, S), (1, -1)),
+             ((S,), (0, 2))
+             )
+
+    def generator():
+        for shape, args in cases:
+            yield SampleInput(make_arg(shape), args=args)
+
+    return list(generator())
+
+
 def sample_inputs_rbinops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs):
     def _make_tensor_helper(shape, low=None, high=None):
         return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
@@ -5094,6 +5111,10 @@
                SkipInfo('TestCommon', 'test_variant_consistency_jit',
                         dtypes=all_types_and_complex_and(torch.bfloat16, torch.half)),),
            assert_autodiffed=True,),
+    OpInfo('select',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+           sample_inputs_func=sample_inputs_select,
+           supports_out=False),
     UnaryUfuncInfo('signbit',
                    ref=np.signbit,
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
@@ -6211,9 +6232,6 @@
         ('fill_', (S, S, S), (1,), 'number'),
         ('fill_', (), (1,), 'number_scalar'),
         ('fill_', (S, S, S), ((),), 'variable'),
-        ('select', (S, S, S), (1, 2), 'dim', (), [0]),
-        ('select', (S, S, S), (1, -1), 'wrap_dim', (), [0]),
-        ('select', (S,), (0, 2), '1d'),
         ('narrow', (S, S, S), (1, 2, 2), 'dim', (), [0]),
         ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', (), [0]),
         ('squeeze', (S, 1, S, 1), NO_ARGS, '', (True,)),