Add OpInfos for the Tensor overload of linspace/logspace (#107958)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107958
Approved by: https://github.com/zou3519
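
The overload under test accepts 0-dim tensors for start and end in place
of Python scalars, in any tensor/scalar combination. A minimal sketch of
the calls these OpInfos exercise (values chosen for illustration):

    import torch

    # start/end as 0-dim tensors; mixed tensor/scalar combinations are
    # also valid, mirroring the (True, True) / (True, False) /
    # (False, True) cases generated by the sample-input functions below.
    torch.linspace(torch.tensor(0.0), torch.tensor(7.0), 50)
    torch.linspace(torch.tensor(0), 7, 50, dtype=torch.int64)
    torch.logspace(torch.tensor(-2.0), torch.tensor(2.0), 5, base=10.0)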
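For reviewers unfamiliar with the OpInfo machinery: a sample_inputs_func
yields SampleInput objects that the generic test suites replay against
the op. A sketch of that consumption loop (illustrative only, not a real
test):

    import torch
    from torch.testing._internal.common_methods_invocations import op_db

    # Pick out the variant added by this PR and run the op over its samples.
    info = next(op for op in op_db
                if op.name == "linspace"
                and op.variant_test_name == "tensor_overload")
    for sample in info.sample_inputs("cpu", torch.float32):
        info.op(sample.input, *sample.args, **sample.kwargs)
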
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 52fb185..d37f2b4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1004,6 +1004,30 @@
yield SampleInput(1, args=(3, 1))
+def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
+ ends = (-3, 0, 1, 4, 50)
+ starts = (-2., 0, 4.3, 50)
+ nsteps = (0, 1, 50)
+ is_start_end_tensors = ((True, True), (True, False), (False, True))
+ make_arg = partial(torch.tensor, device=device, requires_grad=False)
+
+ # Extra case to replicate off-by-one issue on CUDA
+ cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))]
+ for start, end, nstep, (is_start_tensor, is_end_tensor) in cases:
+ if dtype == torch.uint8 and (end < 0 or start < 0):
+ continue
+
+ tensor_options = {"dtype": dtype, "device": device}
+ if is_start_tensor:
+ start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
+ if is_end_tensor:
+ end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)
+
+ yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
+
+ yield SampleInput(1, args=(3, 1))
+
+
def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
ends = (-3, 0, 1.2, 2, 4)
starts = (-2., 0, 1, 2, 4.3)
@@ -1023,6 +1047,35 @@
yield SampleInput(1, args=(3, 1, 2.))
+def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
+ ends = (-3, 0, 1.2, 2, 4)
+ starts = (-2., 0, 1, 2, 4.3)
+ nsteps = (0, 1, 2, 4)
+ bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)  # None uses the default base (10); int8/uint8 get small bases to avoid overflow
+ is_start_end_tensors = ((True, True), (True, False), (False, True))
+ make_arg = partial(torch.tensor, device=device)
+ for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors):
+ if dtype == torch.uint8 and (end < 0 or start < 0):
+ continue
+ if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
+ # https://github.com/pytorch/pytorch/issues/82242
+ continue
+
+ tensor_options = {"dtype": dtype, "device": device}
+
+ if is_start_tensor:
+ start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
+ if is_end_tensor:
+ end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)
+
+ if base is None:
+ yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
+ else:
+ yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options)
+
+ yield SampleInput(1, args=(3, 1, 2.))
+
+
def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs):
yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
@@ -11317,6 +11370,45 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
dtypes=(torch.cfloat,), device_type="cuda"),
)),
+ OpInfo('linspace',
+ dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+ is_factory_function=True,
+ supports_out=True,
+ supports_autograd=False,
+ error_inputs_func=error_inputs_linspace,
+ sample_inputs_func=sample_inputs_linspace_tensor_overload,
+ variant_test_name="tensor_overload",
+ skips=(
+ # FX failed to normalize op - add the op to the op_skip list.
+ DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+ # TypeError: 'int' object is not subscriptable
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+ # Same failure as arange: cannot find linspace in captured graph
+ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+ # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+ # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+ # in __main__.TestJitCUDA.test_variant_consistency_jit_linspace_cuda_complex64!
+ # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+ # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+ dtypes=(torch.cfloat,), device_type="cuda"),
+
+ # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+ DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+ DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
+ DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
+ DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace'),
+ DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'),
+ DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
+ DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
+ DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+ DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+ )),
OpInfo('logspace',
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
@@ -11349,6 +11441,51 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
dtypes=(torch.cfloat,), device_type="cuda"),
)),
+ OpInfo('logspace',
+ dtypes=all_types_and_complex_and(torch.bfloat16),
+ dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
+ is_factory_function=True,
+ supports_out=True,
+ supports_autograd=False,
+ error_inputs_func=error_inputs_linspace,
+ sample_inputs_func=sample_inputs_logspace_tensor_overload,
+ variant_test_name="tensor_overload",
+ skips=(
+ # FX failed to normalize op - add the op to the op_skip list.
+ DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+ # TypeError: 'int' object is not subscriptable
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+ # Same failure as arange: cannot find logspace in captured graph
+ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+ # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+
+ # Off-by-one issue when casting floats to ints
+ DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick',
+ dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+ DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive',
+ dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+ # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+ # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
+ # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+ # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+ dtypes=(torch.cfloat,), device_type="cuda"),
+
+ # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+ DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+ DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
+ DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
+ DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace'),
+ DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'),
+ DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
+ DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
+ DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+ DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+ )),
UnaryUfuncInfo('log',
ref=np.log,
domain=(0, None),
@@ -19000,6 +19137,41 @@
),
),
PythonRefInfo(
+ "_refs.linspace",
+ torch_opinfo_name="linspace",
+ torch_opinfo_variant_name="tensor_overload",
+ skips=(
+ # TypeError: 'int' object is not subscriptable
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+ # cpu implementation is wrong on some integral types
+ # https://github.com/pytorch/pytorch/issues/81996
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+ dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+ dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+
+ # cuda implementation is off-by-one on some inputs due to precision issues
+ # https://github.com/pytorch/pytorch/issues/82230
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+ dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+ device_type="cuda"),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+ dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+ device_type="cuda"),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+ dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+ device_type="cuda"),
+
+ # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+ dtypes=(torch.float64, torch.complex128)),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+ ),
+ ),
+ PythonRefInfo(
"_refs.logspace",
torch_opinfo_name="logspace",
skips=(
@@ -19021,6 +19193,33 @@
),
),
PythonRefInfo(
+ "_refs.logspace",
+ torch_opinfo_name="logspace",
+ torch_opinfo_variant_name="tensor_overload",
+ skips=(
+ # TypeError: 'int' object is not subscriptable
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+ # Off-by-one issue when casting floats to ints
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+ dtypes=(torch.int16, torch.int32, torch.int64),
+ device_type="cuda"),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+ dtypes=(torch.int16, torch.int32, torch.int64),
+ device_type="cuda"),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+ dtypes=(torch.int16, torch.int32, torch.int64),
+ device_type="cuda"),
+
+ # https://github.com/pytorch/pytorch/pull/107958#pullrequestreview-1611367760
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+ dtypes=(torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8)),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+ ),
+ ),
+ PythonRefInfo(
"_refs.meshgrid",
torch_opinfo_name="meshgrid",
torch_opinfo_variant_name="variadic_tensors",