Adds OpInfos for max_unpool{1, 2, 3}d

per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75879
Approved by: https://github.com/ngimel
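
For reference, a minimal sketch (not part of this PR's diff) of the max_pool/max_unpool round-trip that the new sample-input functions exercise; the shapes and kwargs below are illustrative:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 8, requires_grad=True)

# max_pool2d with return_indices=True returns both the pooled values and the
# indices of the selected elements, which max_unpool2d needs.
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

# output_size resolves the ambiguity introduced by the floor/ceil used in the
# pooling shape computation.
unpooled = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2,
                          output_size=x.size())
assert unpooled.shape == x.shape
```
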
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 33f64ac..41be677 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8039,6 +8039,79 @@
         SampleInput(make_input((2,)), kwargs=dict(offset=-1)),
     ]
 
+def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
+    unpool_name_to_pool_method_dict = {
+        'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d,
+        'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d,
+        'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d
+    }
+
+    unpool_name_to_dim = {
+        'nn.functional.max_unpool1d': 1,
+        'nn.functional.max_unpool2d': 2,
+        'nn.functional.max_unpool3d': 3
+    }
+
+    unpool_to_pool_name_dict = {
+        k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()
+    }
+
+    pool_dim = unpool_name_to_dim[op_info.name]
+    pool_method = unpool_name_to_pool_method_dict[op_info.name]
+
+    pool_op_info = copy.copy(op_info)
+    pool_op_info.name = unpool_to_pool_name_dict[op_info.name]
+
+    for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs):
+        # shapes (C, ...) are not supported yet,
+        # see https://github.com/pytorch/pytorch/issues/68337
+        # TODO: remove once the issue is resolved
+        if sample.input.dim() != pool_dim + 2:
+            continue
+
+        # No dilation > 1 for max_unpool,
+        # see https://github.com/pytorch/pytorch/issues/68420
+        if sample.kwargs['dilation'] != 1:
+            continue
+
+        # Can't unpool without indices
+        if sample.kwargs['return_indices']:
+            pool, indices = pool_method(sample.input, **sample.kwargs)
+            # arg has to be a leaf
+            arg = pool.detach().requires_grad_(requires_grad)
+            sample_kwargs = {
+                'kernel_size': sample.kwargs['kernel_size'],
+                'stride': sample.kwargs['stride'],
+                'padding': sample.kwargs['padding'],
+                # output_size could be None, but we specify it explicitly
+                # to compensate for the information loss in pooling due
+                # to the floor/ceil operation used to compute the output shape.
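+                # For example, inputs of length 4 and 5 both max-pool to
+                # length 2 with kernel_size=2 and stride=2, so the unpooled
+                # size is ambiguous without output_size.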
+                'output_size': sample.input.size()
+            }
+
+            yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs)
+
+def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
+        indices = sample.args[0]
+        # The samples for max_unpool are generated with max_pool.
+        # It could be that a single element from max_pool's input
+        # is mapped to several locations in its output.
+        # This situation leads to failed gradchecks because
+        # the finite-difference algorithm perturbs the elements
+        # of the output one by one, and not in the equivalence
+        # classes determined by whether two elements in the output
+        # come from the same location in the input (simply put,
+        # whether they have the same corresponding index).
+        # There are two ways to resolve this issue:
+        # 1. Extract a perturbation for one element and apply it to
+        #    all the elements from the same equivalence class, or
+        # 2. Make sure that the equivalence classes are all singletons,
+        #    i.e. the index tensor is comprised of only unique indices.
+        # Here we go with solution 2, the simpler of the two.
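+        # For example, max_pool1d on [[[0., 5., 0.]]] with kernel_size=2
+        # and stride=1 yields indices [[[1, 1]]]: the middle element is
+        # the max of both overlapping windows, so the indices are not unique.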
+        if indices.unique().numel() == indices.numel():
+            yield sample
 
 foreach_unary_op_db: List[OpInfo] = [
     ForeachFuncInfo('exp'),
@@ -11984,6 +12057,87 @@
            # TODO: investigate nondeterminism
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            sample_inputs_func=sample_inputs_max_pool),
+    OpInfo('nn.functional.max_unpool1d',
+           aten_name='max_unpool1d',
+           supports_autograd=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16),
+           sample_inputs_func=sample_inputs_max_unpool,
+           skips=(
+               # Jacobian mismatch
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+               # Backward is not reentrant
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+           )),
+    OpInfo('nn.functional.max_unpool1d',
+           variant_test_name='grad',
+           aten_name='max_unpool1d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16),
+           sample_inputs_func=sample_inputs_max_unpool_grad,
+           skips=(
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+           )),
+    OpInfo('nn.functional.max_unpool2d',
+           aten_name='max_unpool2d',
+           supports_autograd=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16),
+           sample_inputs_func=sample_inputs_max_unpool,
+           skips=(
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+               # Backward is not reentrant
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
+               # Jacobian mismatch
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+           )),
+    OpInfo('nn.functional.max_unpool2d',
+           variant_test_name='grad',
+           aten_name='max_unpool2d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           # Vmap is not happy with non-contiguous (channels_last) inputs
+           check_batched_grad=False,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16),
+           sample_inputs_func=sample_inputs_max_unpool_grad,
+           skips=(
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+           )),
+    OpInfo('nn.functional.max_unpool3d',
+           aten_name='max_unpool3d',
+           supports_autograd=False,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16),
+           sample_inputs_func=sample_inputs_max_unpool,
+           skips=(
+               # Jacobian mismatch
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+           )),
+    OpInfo('nn.functional.max_unpool3d',
+           variant_test_name='grad',
+           aten_name='max_unpool3d',
+           supports_autograd=False,
+           supports_forward_ad=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16),
+           sample_inputs_func=sample_inputs_max_unpool_grad),
     OpInfo('nn.functional.linear',
            aten_name='linear',
            supports_autograd=True,