OpInfos for new_* functions and some _like functions (#67357)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67357
This PR adds OpInfos for:
- new_ones, new_zeros, new_full, new_empty
- rand_like, randint_like
I forgot to add the _like functions in a previous PR, so here they are.
Test Plan: - wait for tests
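Beyond waiting for CI, the new entries can be smoke-tested locally. The snippet below is only a sketch (it assumes the OpInfo entries added in this diff are present in op_db and relies on the standard OpInfo.sample_inputs interface); it is not a test shipped with the PR:

    import torch
    from torch.testing._internal.common_methods_invocations import op_db

    new_names = {'new_ones', 'new_zeros', 'new_full', 'new_empty',
                 'rand_like', 'randint_like'}
    for opinfo in op_db:
        if opinfo.name not in new_names:
            continue
        # Exercise every sample on CPU/float32; opinfo.op is the callable under
        # test (for the new_* entries it is the method-wrapping lambda below).
        for sample in opinfo.sample_inputs('cpu', torch.float32, requires_grad=False):
            opinfo.op(sample.input, *sample.args, **sample.kwargs)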
Reviewed By: mruberry
Differential Revision: D31969533
Pulled By: zou3519
fbshipit-source-id: 236d70d66e82f1d6f8e5254b55ca2a37b54c9494
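For orientation before reading the diff: the new sample_inputs_new_fns helper yields SampleInput objects whose input is the source tensor, whose single positional arg is the requested output shape, and whose kwargs optionally override dtype and/or device. A minimal sketch of how a consumer uses those samples (illustrative only; it assumes the helper added below is importable and runs on CPU/float32):

    import torch
    from torch.testing._internal.common_methods_invocations import sample_inputs_new_fns

    # None stands in for the op_info argument, which the helper does not use.
    for sample in sample_inputs_new_fns(None, 'cpu', torch.float32, requires_grad=False):
        out = sample.input.new_zeros(*sample.args, **sample.kwargs)
        # The result takes its shape from sample.args[0] and its dtype/device
        # from sample.kwargs when given, otherwise from sample.input.
        assert out.shape == torch.Size(sample.args[0])

sample_inputs_new_full extends these samples with a fill value whose dtype matches the eventual output dtype, mirroring new_full's signature.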
diff --git a/test/test_fx.py b/test/test_fx.py
index 6627744..83b6ad5 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -3208,6 +3208,7 @@
raise RuntimeError(f'Did not match any schemas for op {op.name}!')
+
class TestFXAPIBackwardCompatibility(JitTestCase):
def setUp(self):
self.maxDiff = None
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 73735e7..851d0d58 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -1512,6 +1512,12 @@
'randn_like',
'zeros_like',
'full_like',
+ 'rand_like',
+ 'randint_like',
+ 'new_ones',
+ 'new_empty',
+ 'new_zeros',
+ 'new_full',
"__getitem__",
"__radd__",
"__rsub__",
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index eb3a72d..040268d 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2125,6 +2125,64 @@
return tuple(samples)
+def get_independent_tensor(tensor):
+ return tensor.detach().clone().requires_grad_(tensor.requires_grad)
+
+def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
+ samples = []
+ low = 2
+ high = 10
+
+ for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+ # With high
+ samples.append(SampleInput(
+ sample.input,
+ args=(high,) + sample.args,
+ kwargs=sample.kwargs))
+ # With low and high
+ samples.append(SampleInput(
+ get_independent_tensor(sample.input),
+ args=(low, high,) + sample.args,
+ kwargs=sample.kwargs))
+ return tuple(samples)
+
+def sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
+ inputs = [
+ ((), (), {}),
+ ((S, S), (2, 0), {}),
+ ((0, S, 0), (3, 2, 2), {}),
+ ((S,), (2, 3), {'dtype': dtype, 'device': device}),
+ # Hard-code some dtypes/devices. We want to test cases where the
+ # (dtype, device) is different from the input's (dtype, device)
+ ((S,), (10,), {'dtype': torch.double}),
+ ((S,), (1, 1, 12), {'device': 'cpu'}),
+ ((S,), (2, 2, 2), {'dtype': torch.double, 'device': 'cpu'}),
+ ]
+ if torch.cuda.is_available():
+ inputs.append(((S,), (7, 2), {'device': 'cuda'}))
+
+ samples = []
+ for input_shape, output_shape, kwargs in inputs:
+ t = make_tensor(input_shape, device, dtype,
+ low=None, high=None,
+ requires_grad=requires_grad)
+ samples.append(SampleInput(t, args=(output_shape,), kwargs=kwargs))
+
+ return tuple(samples)
+
+def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
+ def get_val(dtype):
+ return make_tensor([], 'cpu', dtype).item()
+
+ samples = []
+ for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
+ # The scalar we pass to new_full must have the same dtype
+ # as that of the resulting tensor
+ use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
+ samples.append(SampleInput(
+ sample.input, args=sample.args + (get_val(use_dtype),), kwargs=sample.kwargs))
+ return tuple(samples)
+
def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
def get_val(dtype):
return make_tensor([], 'cpu', dtype).item()
@@ -7099,7 +7157,6 @@
torch.manual_seed(42)
return op(input, *args, **kwargs)
-
def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
feature_size = np.prod(normalized_shape)
inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload]
@@ -11206,6 +11263,32 @@
# AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
)),
+ OpInfo('rand_like',
+ dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex64, torch.complex128),
+ op=lambda inp, *args, **kwargs:
+ wrapper_set_seed(torch.rand_like, inp, *args, **kwargs),
+ supports_out=False,
+ sample_inputs_func=sample_inputs_like_fns,
+ supports_autograd=False,
+ skips=(
+ # AssertionError: JIT Test does not execute any logic
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ # Can't find schemas for this operator for some reason
+ DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+ )),
+ OpInfo('randint_like',
+ dtypes=all_types_and(torch.half, torch.bfloat16),
+ op=lambda inp, *args, **kwargs:
+ wrapper_set_seed(torch.randint_like, inp, *args, **kwargs),
+ supports_out=False,
+ sample_inputs_func=sample_inputs_randint_like,
+ supports_autograd=False,
+ skips=(
+ # AssertionError: JIT Test does not execute any logic
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ # Can't find schemas for this operator for some reason
+ DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+ )),
OpInfo('full_like',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_out=False,
@@ -11215,6 +11298,56 @@
# Can't find schemas for this operator for some reason
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
)),
+ OpInfo('new_zeros',
+ op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ supports_out=False,
+ sample_inputs_func=sample_inputs_new_fns,
+ skips=(
+ # Can't find schemas for this operator for some reason
+ DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+ ),
+ supports_autograd=False),
+ OpInfo('new_ones',
+ op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ supports_out=False,
+ sample_inputs_func=sample_inputs_new_fns,
+ skips=(
+ # Can't find schemas for this operator for some reason
+ DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+ ),
+ supports_autograd=False),
+ OpInfo('new_empty',
+ op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ supports_out=False,
+ sample_inputs_func=sample_inputs_new_fns,
+ skips=(
+ # Empty tensor data is garbage so it's hard to make comparisons with it.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ # Empty tensor data is garbage so it's hard to make comparisons with it.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+ # Empty tensor data is garbage so it's hard to make comparisons with it.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+ # Empty tensor data is garbage so it's hard to make comparisons with it.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+ # Empty tensor data is garbage so it's hard to make comparisons with it.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+ # Can't find schemas for this operator for some reason
+ DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+ ),
+ supports_autograd=False),
+ OpInfo('new_full',
+ op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ supports_out=False,
+ sample_inputs_func=sample_inputs_new_full,
+ skips=(
+ # Can't find schemas for this operator for some reason
+ DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+ ),
+ supports_autograd=False),
OpInfo('scatter_add',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_scatter_add,