OpInfo: clone, contiguous (#58390)
Summary:
Adds OpInfo entries (with sample-input functions) for `clone` and `contiguous`, and removes the now-redundant `method_tests` entries for both ops.
Reference: https://github.com/pytorch/pytorch/issues/54261
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58390
Reviewed By: soulitzer
Differential Revision: D28567821
Pulled By: mruberry
fbshipit-source-id: bcf42cb4a9a57d8a15a76819b8a9e2df97cf00be
diff --git a/test/test_fx.py b/test/test_fx.py
index 6414318..7cf9190 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -2604,6 +2604,7 @@
def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
# Sorted and one entry on each line to minimize merge conflicts.
known_no_schema = {'cdist',
+ 'contiguous',
'dstack',
'einsum',
'expand',
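A plausible reason `contiguous` lands in this skip list (and in the test_fx_experimental.py one below) while `clone` does not is that, unlike `clone`, it is exposed only as a `Tensor` method with no `torch`-namespace counterpart, so there is no function schema for the exhaustive signature test to resolve. A quick check of that distinction (an illustration, not part of the commit):

```python
import torch

# clone is exposed both as a torch-namespace function and a Tensor method,
# so a schema can be looked up by function name:
assert hasattr(torch, 'clone') and hasattr(torch.Tensor, 'clone')

# contiguous exists only as a Tensor method; there is no torch.contiguous
# function whose signature the FX tests could retrieve:
assert not hasattr(torch, 'contiguous')
assert hasattr(torch.Tensor, 'contiguous')
```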
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index e3c1023..4ad874c 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -1313,7 +1313,8 @@
@ops(op_db, allowed_dtypes=(torch.float,))
def test_normalize_operator_exhaustive(self, device, dtype, op):
# Sorted and one entry on each line to minimize merge conflicts.
- op_skip = {'einsum',
+ op_skip = {'contiguous',
+ 'einsum',
'expand',
'expand_as',
'gradient',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 64b9c7d..ac2a1ef 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3465,6 +3465,26 @@
return samples
+def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
+ make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+ def generator():
+ yield SampleInput(make_arg((S, M, S)))
+ yield SampleInput(make_arg(()))
+
+ return list(generator())
+
+
+def sample_inputs_contiguous(op_info, device, dtype, requires_grad, **kwargs):
+ make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+ def generator():
+ yield SampleInput(make_arg((S, S)))
+ yield SampleInput(make_arg((S, S), noncontiguous=True))
+
+ return list(generator())
+
+
def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
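For a sense of how these sample-input functions are consumed (a sketch following the `torch.testing._internal` conventions of the time, not code from this diff; `S` and `M` are the small size constants, 5 and 10, defined in this file):

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

# Look up the new entry and exercise each SampleInput it yields, the way
# the OpInfo-driven tests do.
clone_info = next(op for op in op_db if op.name == 'clone')
for sample in clone_info.sample_inputs('cpu', torch.float32, requires_grad=False):
    result = clone_info.op(sample.input, *sample.args, **sample.kwargs)
    # clone preserves shape for both the (S, M, S) and the scalar sample.
    assert result.shape == sample.input.shape
```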
@@ -4173,6 +4193,21 @@
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
sample_inputs_func=sample_inputs_chunk,
supports_out=False),
+ OpInfo('clone',
+ dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+ sample_inputs_func=sample_inputs_clone,
+ supports_forward_ad=True,
+ supports_out=False),
+ OpInfo('contiguous',
+ op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs),
+ dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+ sample_inputs_func=sample_inputs_contiguous,
+ supports_forward_ad=True,
+ skips=(
+ # JIT has an issue when the op is passed as a lambda
+ SkipInfo('TestCommon', 'test_variant_consistency_jit'),
+ ),
+ supports_out=False),
OpInfo('symeig',
dtypes=floating_and_complex_types(),
check_batched_gradgrad=False,
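The lambda in the `contiguous` entry adapts the method-only API to OpInfo's function-call convention, and the `SkipInfo` is needed because TorchScript cannot compile lambdas. A minimal sketch of both points (an illustration, not part of the diff):

```python
import torch

# The same shape of wrapper the OpInfo entry above uses:
op = lambda x, *args, **kwargs: x.contiguous(*args, **kwargs)

t = torch.randn(3, 4).transpose(0, 1)  # a non-contiguous view
out = op(t)
assert out.is_contiguous() and torch.equal(out, t)

# Scripting a lambda fails, which is why test_variant_consistency_jit
# is skipped for this entry:
try:
    torch.jit.script(op)
except Exception:
    pass  # expected: TorchScript cannot retrieve source for a lambda
```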
@@ -6510,10 +6545,6 @@
('norm', (), (3, 0), '3_dim_scalar', (), [1]),
('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', (), [1]),
('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', (), [1]),
- ('clone', (S, M, S), NO_ARGS),
- ('clone', (), NO_ARGS, 'scalar'),
- ('contiguous', (S, S), NO_ARGS, '', (True,)),
- ('contiguous', torch.randn(S, S).transpose(0, 1), NO_ARGS, 'not_contiguous', (True,)),
('diag_embed', (S, S), NO_ARGS),
('diagonal', (M, M), NO_ARGS, '2d'),
('diagonal', (3, 5), NO_ARGS, '2d_wide'),
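The removed `method_tests` tuples are subsumed by the new OpInfo entries: the sample-input functions above reproduce the same cases (multi-dimensional, scalar, and non-contiguous inputs), and gradient coverage now comes from the generic OpInfo autograd tests. A rough standalone equivalent of what the removed tuples verified (a sketch, not the actual harness):

```python
import torch
from torch.autograd import gradcheck

# clone over a multi-dimensional tensor and a scalar, as the removed
# (S, M, S) and () entries did:
for shape in [(5, 10, 5), ()]:
    x = torch.randn(shape, dtype=torch.double, requires_grad=True)
    assert gradcheck(torch.clone, (x,))

# contiguous over contiguous and non-contiguous inputs:
x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
assert gradcheck(lambda t: t.contiguous(), (x,))

y = torch.randn(5, 5, dtype=torch.double).transpose(0, 1).detach().requires_grad_(True)
assert gradcheck(lambda t: t.contiguous(), (y,))
```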