[complex] conv_transpose3d: complex support (#87967)

Reference: https://github.com/pytorch/pytorch/issues/71108

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87967
Approved by: https://github.com/anjali411
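
This makes `torch.nn.functional.conv_transpose3d` accept complex inputs and weights. A minimal usage sketch (assuming a build that includes this change; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# (N, C_in, T, H, W) input and (C_in, C_out, kT, kH, kW) weight, both complex.
x = torch.randn(1, 2, 3, 3, 3, dtype=torch.cfloat)
w = torch.randn(2, 4, 2, 2, 2, dtype=torch.cfloat)

out = F.conv_transpose3d(x, w)
print(out.shape, out.dtype)  # torch.Size([1, 4, 4, 4, 4]) torch.complex64
```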
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 6021580..109f0ac 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -1087,8 +1087,14 @@
   Tensor input;
   bool is_batched;
   std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
-  auto output = at::convolution(
+  Tensor output;
+  if (at::isComplexType(input_.scalar_type())) {
+    output = complex_convolution(
       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  } else {
+    output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  }
   return is_batched ? output : output.squeeze(0);
 }
 
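For complex inputs the forward now routes through the existing `complex_convolution` helper instead of `at::convolution`. The underlying identity is that a complex (transposed) convolution can be assembled from real (transposed) convolutions of the real and imaginary parts. The sketch below checks that identity numerically; it illustrates the principle and is not necessarily the exact sequence of kernel calls `complex_convolution` performs internally.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 3, 3, 3, dtype=torch.cdouble)
w = torch.randn(2, 3, 2, 2, 2, dtype=torch.cdouble)

# Complex path enabled by this change.
out = F.conv_transpose3d(x, w)

# (a + ib) conv (c + id) = (a conv c - b conv d) + i (a conv d + b conv c)
real = F.conv_transpose3d(x.real, w.real) - F.conv_transpose3d(x.imag, w.imag)
imag = F.conv_transpose3d(x.real, w.imag) + F.conv_transpose3d(x.imag, w.real)
torch.testing.assert_close(out, torch.complex(real, imag))
```
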
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e8c7b41..ba2e8bc 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10673,6 +10673,9 @@
                DecorateInfo(
                    toleranceOverride({torch.complex32: tol(atol=1e-5, rtol=5e-3)}),
                    "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+               DecorateInfo(
+                   toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }),
+                   'TestCommon', 'test_numpy_ref_mps'),
            ),
            skips=(
                # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486
@@ -10741,8 +10744,12 @@
     OpInfo('nn.functional.conv_transpose3d',
            aten_name='conv_transpose3d',
            aliases=('conv_transpose3d',),
-           dtypes=floating_types_and(torch.int64),
-           dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
+           # `ref` for this function is backward of
+           # corresponding `conv*d`
+           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d),
+           dtypes=floating_and_complex_types_and(torch.int64),
+           dtypesIfCUDA=floating_and_complex_types_and(
+               torch.float16, torch.chalf, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            sample_inputs_func=sample_inputs_conv_transpose3d,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -10752,24 +10759,45 @@
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            decorators=[
                DecorateInfo(
-                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
+                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06),
+                                     torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
                    'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }),
                    'TestCompositeCompliance', 'test_operator', device_type='cuda'),
                DecorateInfo(
-                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), }),
+                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06),
+                                     torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
                    'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }),
                    'TestCompositeCompliance', 'test_forward_ad', device_type='cuda',
-                   active_if=TEST_CUDNN)],
+                   active_if=TEST_CUDNN),
+               DecorateInfo(
+                   toleranceOverride({torch.complex32: tol(atol=5e-2, rtol=5e-2)}),
+                   "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}),
+                   "TestMathBits", "test_conj_view", device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }),
+                   'TestCommon', 'test_complex_half_reference_testing')],
            skips=(
                # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
                # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.skip("Skipped! 75029"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
                DecorateInfo(unittest.skip("Skipped! 75363"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
+               # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long'
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.int64,)),
+               # Reference: https://github.com/pytorch/pytorch/issues/86356
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.double, torch.cdouble)),
+               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
            ),
            supports_out=False,),
     OpInfo('nn.functional.conv1d',
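The new `ref=` above uses `conv_transpose_ref`, which, as the in-line comment notes, relies on transposed convolution being the input-gradient of the corresponding forward convolution. A small real-valued sketch of that relationship (standard PyTorch, no assumptions beyond this change):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 5, 5, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(3, 2, 2, 2, 2, dtype=torch.double)  # conv3d weight: (C_out, C_in, kT, kH, kW)

y = F.conv3d(x, w)
g = torch.randn_like(y)

# Gradient of conv3d w.r.t. its input for an incoming gradient g ...
(grad_x,) = torch.autograd.grad(y, x, g)

# ... matches conv_transpose3d applied to g with the same weight.
torch.testing.assert_close(grad_x, F.conv_transpose3d(g, w))
```
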
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index fed908e..1f395cb 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -1179,6 +1179,7 @@
                )),
     ModuleInfo(torch.nn.ConvTranspose3d,
                module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
+               dtypes=floating_and_complex_types_and(torch.chalf),
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                module_memformat_affects_out=True,
                skips=(
@@ -1190,9 +1191,28 @@
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+                   # These fail only on ROCm
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule',
+                                'test_if_train_and_eval_modes_differ', dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors',
+                                dtypes=(torch.chalf,), device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_multiple_device_transfer',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   # Ref: https://github.com/pytorch/pytorch/issues/73502
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_pickle', dtypes=(torch.chalf,)),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
                )),
     ModuleInfo(torch.nn.ELU,
                module_inputs_func=module_inputs_torch_nn_ELU,
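
The `ConvTranspose3d` ModuleInfo above now covers complex dtypes (including chalf) at the module level. A minimal sketch of the module-level usage these configurations exercise (cfloat works on CPU; chalf paths are expected to need CUDA, per the CPU expectedFailures above):

```python
import torch
import torch.nn as nn

# Complex-valued transposed 3d convolution module.
m = nn.ConvTranspose3d(2, 3, kernel_size=2, dtype=torch.cfloat)
x = torch.randn(1, 2, 4, 4, 4, dtype=torch.cfloat)
y = m(x)
print(y.shape, y.dtype)  # torch.Size([1, 3, 5, 5, 5]) torch.complex64
```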