Add Opinfos for the Tensor overload of linspace/logspace (#107958)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107958
Approved by: https://github.com/zou3519
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 52fb185..d37f2b4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1004,6 +1004,30 @@
     yield SampleInput(1, args=(3, 1))
 
 
+def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
+    ends = (-3, 0, 1, 4, 50)
+    starts = (-2., 0, 4.3, 50)
+    nsteps = (0, 1, 50)
+    is_start_end_tensors = ((True, True), (True, False), (False, True))
+    make_arg = partial(torch.tensor, device=device, requires_grad=False)
+
+    # Extra case to replicate off-by-one issue on CUDA
+    cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))]
+    for start, end, nstep, (is_start_tensor, is_end_tensor) in cases:
+        if dtype == torch.uint8 and (end < 0 or start < 0):
+            continue
+
+        tensor_options = {"dtype": dtype, "device": device}
+        if is_start_tensor:
+            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
+        if is_end_tensor:
+            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)
+
+        yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
+
+    yield SampleInput(1, args=(3, 1))
+
+
 def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
     ends = (-3, 0, 1.2, 2, 4)
     starts = (-2., 0, 1, 2, 4.3)
@@ -1023,6 +1047,35 @@
     yield SampleInput(1, args=(3, 1, 2.))
 
 
+def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
+    ends = (-3, 0, 1.2, 2, 4)
+    starts = (-2., 0, 1, 2, 4.3)
+    nsteps = (0, 1, 2, 4)
+    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
+    is_start_end_tensors = ((True, True), (True, False), (False, True))
+    make_arg = partial(torch.tensor, device=device)
+    for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors):
+        if dtype == torch.uint8 and end < 0 or start < 0:
+            continue
+        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
+            # https://github.com/pytorch/pytorch/issues/82242
+            continue
+
+        tensor_options = {"dtype": dtype, "device": device}
+
+        if (is_start_tensor):
+            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
+        if (is_end_tensor):
+            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)
+
+        if base is None:
+            yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
+        else:
+            yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options)
+
+    yield SampleInput(1, args=(3, 1, 2.))
+
+
 def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs):
     yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
 
@@ -11317,6 +11370,45 @@
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
                             dtypes=(torch.cfloat,), device_type="cuda"),
            )),
+    OpInfo('linspace',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           is_factory_function=True,
+           supports_out=True,
+           supports_autograd=False,
+           error_inputs_func=error_inputs_linspace,
+           sample_inputs_func=sample_inputs_linspace_tensor_overload,
+           variant_test_name="tensor_overload",
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # TypeError: 'int' object is not subscriptable
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
+               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.cfloat,), device_type="cuda"),
+
+               # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+           )),
     OpInfo('logspace',
            dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
@@ -11349,6 +11441,51 @@
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
                             dtypes=(torch.cfloat,), device_type="cuda"),
            )),
+    OpInfo('logspace',
+           dtypes=all_types_and_complex_and(torch.bfloat16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
+           is_factory_function=True,
+           supports_out=True,
+           supports_autograd=False,
+           error_inputs_func=error_inputs_linspace,
+           sample_inputs_func=sample_inputs_logspace_tensor_overload,
+           variant_test_name="tensor_overload",
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # TypeError: 'int' object is not subscriptable
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+
+               # Off-by-one issue when casting floats to ints
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick',
+                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive',
+                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
+               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.cfloat,), device_type="cuda"),
+
+               # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+           )),
     UnaryUfuncInfo('log',
                    ref=np.log,
                    domain=(0, None),
@@ -19000,6 +19137,41 @@
         ),
     ),
     PythonRefInfo(
+        "_refs.linspace",
+        torch_opinfo_name="linspace",
+        torch_opinfo_variant_name="tensor_overload",
+        skips=(
+            # TypeError: 'int' object is not subscriptable
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+            # cpu implementation is wrong on some integral types
+            # https://github.com/pytorch/pytorch/issues/81996
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+
+            # cuda implementation is off-by-one on some inputs due to precision issues
+            # https://github.com/pytorch/pytorch/issues/82230
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+
+            # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.float64, torch.complex128)),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
+    PythonRefInfo(
         "_refs.logspace",
         torch_opinfo_name="logspace",
         skips=(
@@ -19021,6 +19193,33 @@
         ),
     ),
     PythonRefInfo(
+        "_refs.logspace",
+        torch_opinfo_name="logspace",
+        torch_opinfo_variant_name="tensor_overload",
+        skips=(
+            # TypeError: 'int' object is not subscriptable
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+            # Off-by-one issue when casting floats to ints
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+
+            # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8)),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
+    PythonRefInfo(
         "_refs.meshgrid",
         torch_opinfo_name="meshgrid",
         torch_opinfo_variant_name="variadic_tensors",