Adds OpInfos for max_unpool{1, 2, 3}d
As the title says: this adds OpInfo entries and sample-input generators for nn.functional.max_unpool1d, max_unpool2d, and max_unpool3d, plus 'grad' variants that keep only samples with unique pooling indices so that gradcheck is well defined.
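The sample inputs are built by reusing sample_inputs_max_pool: each pooled sample generated with return_indices=True is turned into an unpool sample. A minimal sketch of that round trip, outside the diff and with arbitrary shapes and pooling parameters, looks like this:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)

# max_pool2d returns the pooled values together with the indices of
# the selected elements when return_indices=True.
pooled, indices = F.max_pool2d(x, kernel_size=3, stride=2, padding=1,
                               return_indices=True)

# max_unpool2d scatters the pooled values back to their original
# positions; passing output_size explicitly resolves the shape
# ambiguity introduced by the floor/ceil used to compute the pooled shape.
unpooled = F.max_unpool2d(pooled, indices, kernel_size=3, stride=2,
                          padding=1, output_size=x.size())
assert unpooled.shape == x.shape
```

This is the same reason the generator below passes output_size=sample.input.size() instead of leaving it as None.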
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75879
Approved by: https://github.com/ngimel
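One detail worth calling out before the diff: when pooling windows overlap, the same input element can be selected by several windows, so the index tensor contains duplicates and finite-difference gradcheck perturbs output elements that the backward treats as a single input location. The 'grad' variants therefore keep only samples whose indices are all unique. An illustrative sketch of that duplicate-index check (arbitrary values, not part of the diff):

```python
import torch
import torch.nn.functional as F

# Overlapping windows (stride < kernel_size) can pick the same input
# element more than once, producing duplicate indices.
x = torch.tensor([[[[0., 0., 0.],
                    [0., 9., 0.],
                    [0., 0., 0.]]]])
_, indices = F.max_pool2d(x, kernel_size=2, stride=1, return_indices=True)
print(indices)                                      # flat index 4 appears four times
print(indices.unique().numel() == indices.numel())  # False -> sample is skipped

# Non-overlapping windows yield unique indices, so the sample is kept.
_, indices = F.max_pool2d(x.repeat(1, 1, 2, 2), kernel_size=3, stride=3,
                          return_indices=True)
print(indices.unique().numel() == indices.numel())  # True
```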
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 33f64ac..41be677 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8039,6 +8039,79 @@
SampleInput(make_input((2,)), kwargs=dict(offset=-1)),
]
+def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
+ unpool_name_to_pool_method_dict = {
+ 'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d,
+ 'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d,
+ 'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d
+ }
+
+ unpool_name_to_dim = {
+ 'nn.functional.max_unpool1d': 1,
+ 'nn.functional.max_unpool2d': 2,
+ 'nn.functional.max_unpool3d': 3
+ }
+
+ unpool_to_pool_name_dict = {
+ k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()
+ }
+
+ pool_dim = unpool_name_to_dim[op_info.name]
+ pool_method = unpool_name_to_pool_method_dict[op_info.name]
+
+ pool_op_info = copy.copy(op_info)
+ pool_op_info.name = unpool_to_pool_name_dict[op_info.name]
+
+ for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs):
+ # Unbatched shapes (C, ...) are not supported yet,
+ # see https://github.com/pytorch/pytorch/issues/68337
+ # TODO: remove once the issue is resolved
+ if sample.input.dim() != pool_dim + 2:
+ continue
+
+ # No dilation > 1 for max_unpool,
+ # see https://github.com/pytorch/pytorch/issues/68420
+ if sample.kwargs['dilation'] != 1:
+ continue
+
+ # Can't unpool without indices
+ if sample.kwargs['return_indices']:
+ pool, indices = pool_method(sample.input, **sample.kwargs)
+ # arg has to be a leaf
+ arg = pool.detach().requires_grad_(requires_grad)
+ sample_kwargs = {
+ 'kernel_size': sample.kwargs['kernel_size'],
+ 'stride': sample.kwargs['stride'],
+ 'padding': sample.kwargs['padding'],
+ # output_size could be None, but we specify it explicitly
+ # to compensate for the information loss in pooling caused
+ # by the floor/ceil operation used to compute the output shapes
+ 'output_size': sample.input.size()
+ }
+
+ yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs)
+
+def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs):
+ for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
+ indices = sample.args[0]
+ # The samples for max_unpool are generated with max_pool.
+ # A single element of the max_pool input can be mapped to
+ # several locations in its output (this happens when the
+ # pooling windows overlap).
+ # That situation leads to failed gradchecks because
+ # the finite-difference algorithm perturbs the elements
+ # of the output one by one, and not in equivalence classes
+ # determined by whether two elements of the output
+ # come from the same location in the input
+ # (simply put, whether they share the same index).
+ # There are two ways to resolve this issue:
+ # 1. Extract a perturbation for one element and apply it to
+ # all the elements of the same equivalence class, or
+ # 2. Make sure that the equivalence classes are all singletons,
+ # i.e. that the index tensor contains only unique indices.
+ # Here we go with solution 2, the simpler of the two.
+ if indices.unique().numel() == indices.numel():
+ yield sample
foreach_unary_op_db: List[OpInfo] = [
ForeachFuncInfo('exp'),
@@ -11984,6 +12057,87 @@
# TODO: investigate nondeterminism
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
sample_inputs_func=sample_inputs_max_pool),
+ OpInfo('nn.functional.max_unpool1d',
+ aten_name='max_unpool1d',
+ supports_autograd=True,
+ supports_out=False,
+ assert_jit_shape_analysis=False,
+ dtypes=floating_types(),
+ dtypesIfCUDA=floating_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_max_unpool,
+ skips=(
+ # Jacobian mismatch
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+ # Backward is not reentrant
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+ )),
+ OpInfo('nn.functional.max_unpool1d',
+ variant_test_name='grad',
+ aten_name='max_unpool1d',
+ supports_autograd=True,
+ supports_forward_ad=True,
+ supports_out=False,
+ assert_jit_shape_analysis=False,
+ dtypes=floating_types(),
+ dtypesIfCUDA=floating_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_max_unpool_grad,
+ skips=(
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+ )),
+ OpInfo('nn.functional.max_unpool2d',
+ aten_name='max_unpool2d',
+ supports_autograd=True,
+ supports_out=False,
+ assert_jit_shape_analysis=False,
+ dtypes=floating_types(),
+ dtypesIfCUDA=floating_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_max_unpool,
+ skips=(
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+ # Backward is not reentrant
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
+ # Jacobian mismatch
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+ )),
+ OpInfo('nn.functional.max_unpool2d',
+ variant_test_name='grad',
+ aten_name='max_unpool2d',
+ supports_autograd=True,
+ supports_forward_ad=True,
+ # Vmap is not happy with non-contiguous (channels_last) inputs
+ check_batched_grad=False,
+ supports_out=False,
+ assert_jit_shape_analysis=False,
+ dtypes=floating_types(),
+ dtypesIfCUDA=floating_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_max_unpool_grad,
+ skips=(
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+ )),
+ OpInfo('nn.functional.max_unpool3d',
+ aten_name='max_unpool3d',
+ supports_autograd=False,
+ supports_out=False,
+ assert_jit_shape_analysis=False,
+ dtypes=floating_types(),
+ dtypesIfCUDA=floating_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_max_unpool,
+ skips=(
+ # Jacobian mismatch
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
+ )),
+ OpInfo('nn.functional.max_unpool3d',
+ variant_test_name='grad',
+ aten_name='max_unpool3d',
+ supports_autograd=False,
+ supports_forward_ad=True,
+ supports_out=False,
+ assert_jit_shape_analysis=False,
+ dtypes=floating_types(),
+ dtypesIfCUDA=floating_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_max_unpool_grad),
OpInfo('nn.functional.linear',
aten_name='linear',
supports_autograd=True,