Enable complex autograd for col2im / im2col (#68199)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68199
Test Plan: Imported from OSS
Reviewed By: VitalyFedyunin
Differential Revision: D32467043
Pulled By: mruberry
fbshipit-source-id: 9094aff036f75b280422e210f7089140ea61fc71
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index c3faf76..9cea6e2 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -107,7 +107,7 @@
'index', 'masked_fill', 'linalg_cross', 'lu_unpack', 'renorm', '_conj_physical',
'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
- 'linalg_pinv', 'linalg_lstsq',
+ 'linalg_pinv', 'linalg_lstsq', 'col2im', 'col2im_backward', 'im2col', 'im2col_backward',
}
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4e9ec10..9b75c94 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8684,10 +8684,10 @@
skipCPUIfNoFFT,
DecorateInfo(unittest.skip("Skipped! istft does not match the native function"),
'TestJit', 'test_variant_consistency_jit'),
+ # gradcheck fails on ROCm (gh-68429)
+ DecorateInfo(skipCUDAIfRocm, 'TestGradients', 'test_fn_grad'),
],
dtypes=floating_and_complex_types(),
- # FIXME: col2im does not support automatic differentiation for outputs with complex dtype.
- supports_autograd=False,
sample_inputs_func=lambda *a, **kw: list(sample_inputs_istft(*a, **kw)),
check_batched_grad=False,
check_batched_gradgrad=False,
@@ -9827,8 +9827,8 @@
autodiff_nonfusible_nodes=["aten::hardswish"]),
OpInfo('nn.functional.unfold',
aten_name='im2col',
- dtypes=floating_types_and(torch.half),
- dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.half),
+ dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_nn_unfold,
skips=(
# RuntimeError: false