[autograd] disable backward/grad for complex scalar output (#92753)

Fixes https://github.com/pytorch/pytorch/issues/92750

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92753
Approved by: https://github.com/ezyang
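
This change makes implicit gradient creation fail loudly when the output passed to `.backward()` / `torch.autograd.grad()` is a complex scalar: the Python path (`torch/autograd/__init__.py`) now raises `RuntimeError("grad can be implicitly created only for real scalar outputs but got <dtype>")`, and the C++ path (`torch/csrc/autograd/autograd.cpp`) raises "grad can be computed only for real scalar outputs but got <dtype>". Tests that previously called `.sum().backward()` on complex outputs are updated to reduce to a real scalar first, typically via `.abs()`.

A minimal sketch of the new behavior, assuming a PyTorch build that includes this patch (tensor values are illustrative only):

```python
import torch

a = torch.zeros(1, requires_grad=True)
b = (a * 0.5j).sum()  # complex scalar output

# Implicit gradient creation is now rejected instead of silently using
# torch.ones_like(b):
#   RuntimeError: grad can be implicitly created only for real scalar outputs
#   but got torch.complex64
# torch.autograd.grad(b, a) raises the same error.
try:
    b.backward()
except RuntimeError as e:
    print(e)

# The pattern used throughout the updated tests: reduce the complex output
# to a real scalar (here via .abs()) before calling backward().
(a * 0.5j).sum().abs().backward()
print(a.grad)  # gradient of a real-valued loss w.r.t. a real input
```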
diff --git a/test/autograd/test_complex.py b/test/autograd/test_complex.py
index c8796a4..5162e03 100644
--- a/test/autograd/test_complex.py
+++ b/test/autograd/test_complex.py
@@ -15,11 +15,11 @@
         x1 = torch.view_as_complex(x0)
         x2 = torch.view_as_real(x1)
         x2.mul_(2)
-        x2.sum().backward()
+        x2.sum().abs().backward()
 
         y0 = y.clone()
         y0.mul_(2)
-        y0.sum().backward()
+        y0.sum().abs().backward()
 
         self.assertEqual(x.grad, y.grad)
 
@@ -35,11 +35,11 @@
 
         x0 = fn(x)
         x0.mul_(2)
-        x0.sum().backward()
+        x0.sum().abs().backward()
 
         y0 = fn(y)
         y1 = y0.mul(2)
-        y1.sum().backward()
+        y1.sum().abs().backward()
 
         self.assertEqual(x.grad, y.grad)
 
@@ -55,11 +55,11 @@
 
         x0 = fn(x)
         x0.mul_(2)
-        x0.sum().backward()
+        x0.sum().abs().backward()
 
         y0 = fn(y)
         y1 = y0.mul(2)
-        y1.sum().backward()
+        y1.sum().abs().backward()
 
         self.assertEqual(x.grad, y.grad)
 
diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp
index 78d629f..2c4352e 100644
--- a/test/cpp/api/tensor.cpp
+++ b/test/cpp/api/tensor.cpp
@@ -1099,6 +1099,13 @@
       y.backward(), "grad can be implicitly created only for scalar outputs");
 }
 
+TEST(TensorTest, BackwardComplexScalarOutput) {
+  auto x = torch::randn({5, 5}, torch::requires_grad());
+  auto y = (x * c10::Scalar(c10::complex<float>(0, 0.5))).sum();
+  ASSERT_THROWS_WITH(
+      y.backward(), "grad can be computed only for real scalar outputs");
+}
+
 TEST(TensorTest, IsLeaf) {
   auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
   auto y = x * x;
diff --git a/test/cpp_api_parity/module_impl_check.py b/test/cpp_api_parity/module_impl_check.py
index 6e44809..bbfad91 100644
--- a/test/cpp_api_parity/module_impl_check.py
+++ b/test/cpp_api_parity/module_impl_check.py
@@ -65,7 +65,11 @@
   write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
 
   // Backward pass
-  cpp_output.sum().backward();
+  if (cpp_output.is_complex()) {
+    cpp_output.sum().abs().backward();
+  } else {
+    cpp_output.sum().backward();
+  }
 
   // Put all gradients into a c10::Dict, save it into a file to be compared in Python later
   c10::Dict<std::string, torch::Tensor> grad_dict;
@@ -109,7 +113,10 @@
     script_module = torch.jit.trace(module, torch.tensor(0))
 
     # Backward pass
-    python_output.sum().backward()
+    if python_output.dtype.is_complex:
+        python_output.sum().abs().backward()
+    else:
+        python_output.sum().backward()
 
     # Put all gradients into a dict, to be compared later
     python_grad_dict = {}
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 4009e37..508d0d1 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2549,10 +2549,10 @@
             flat_out, _ = pytree.tree_flatten(out)
             sm = 0
             for i in flat_out:
-                sm += i.sum()
+                sm += i.sum().abs()
             sm.backward()
         else:
-            out.sum().backward()
+            out.sum().abs().backward()
 
     def reset_grads():
         def f(x):
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 3ca8839..c75ef62 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -369,9 +369,10 @@
             assert not x.is_conj()
             y = x.conj()
             assert y.is_conj()
-            return y
+            return y.abs()
         res = grad(foo)(x)
-        self.assertEqual(res, torch.ones_like(res))
+        with torch.no_grad():
+            self.assertEqual(res, torch.ones_like(res) * torch.sgn(x))
 
     def test_composed_with_autograd(self, device):
         x = torch.randn([], requires_grad=True, device=device)
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index c1dec9a..0e4d807 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -394,7 +394,7 @@
         tol1('masked.cumprod',
              {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
         tol1('svd_lowrank',
-             {torch.float32: tol(atol=3e-05, rtol=3e-05)}, device_type='cuda'),
+             {torch.float32: tol(atol=3e-05, rtol=3e-04)}, device_type='cuda'),
         tol1('linalg.tensorsolve',
              {torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type='cuda'),
     ))
@@ -430,10 +430,15 @@
                 if sample.output_process_fn_grad is not None:
                     result = sample.output_process_fn_grad(result)
 
+                def abs_if_complex(t):
+                    if t.dtype.is_complex:
+                        return t.abs()
+                    return t
+
                 # Reduce into single value for grad
                 if isinstance(result, torch.Tensor):
-                    return result.sum()
-                result = sum([res.sum() for res in result])
+                    return abs_if_complex(result.sum())
+                result = sum([abs_if_complex(res.sum()) for res in result])
                 return result
 
             result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py
index 5413513..f35a777 100644
--- a/test/nn/test_convolution.py
+++ b/test/nn/test_convolution.py
@@ -1270,24 +1270,24 @@
 
         # Symmetric padding
         z = F.conv1d(x, y, padding=3, dilation=2)
-        z.sum().backward()
+        z.sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
         z = F.conv1d(x, y, padding='same', dilation=2)
-        z.sum().backward()
+        z.sum().abs().backward()
         self.assertEqual(gx_expect, x.grad)
         self.assertEqual(gy_expect, y.grad)
         x.grad, y.grad = None, None
 
         # Asymmetric padding
         z = F.conv1d(x, y, padding=2)[..., 1:]
-        z.sum().backward()
+        z.sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
         z = F.conv1d(x, y, padding='same')
-        z.sum().backward()
+        z.sum().abs().backward()
         self.assertEqual(gx_expect, x.grad)
         self.assertEqual(gy_expect, y.grad)
 
@@ -1299,12 +1299,12 @@
 
         # Symmetric padding
         z = F.conv2d(x, y, padding=(3, 4), dilation=2)
-        z.sum().backward()
+        z.sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
         z = F.conv2d(x, y, padding='same', dilation=2)
-        z.sum().backward()
+        z.sum().abs().backward()
         self.assertEqual(gx_expect, x.grad)
         self.assertEqual(gy_expect, y.grad)
         x.grad, y.grad = None, None
@@ -1312,12 +1312,12 @@
         # Asymmetric padding
         y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
         z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
-        z.sum().backward()
+        z.sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
         z = F.conv2d(x, y, padding='same')
-        z.sum().backward()
+        z.sum().abs().backward()
         self.assertEqual(gx_expect, x.grad)
         self.assertEqual(gy_expect, y.grad)
 
@@ -1331,12 +1331,12 @@
 
         # Symmetric padding
         z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
-        z.sum().backward()
+        z.sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
         z = F.conv3d(x, y, padding='same', dilation=2)
-        z.sum().backward()
+        z.sum().abs().backward()
         self.assertEqual(gx_expect, x.grad)
         self.assertEqual(gy_expect, y.grad)
         x.grad, y.grad = None, None
@@ -1351,12 +1351,12 @@
         # Asymmetric padding
         y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
         z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
-        z.sum().backward()
+        z.sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
         z = F.conv3d(x, y, padding='same')
-        z.sum().backward()
+        z.sum().abs().backward()
         self.assertEqual(gx_expect, x.grad)
         self.assertEqual(gy_expect, y.grad)
 
@@ -1372,11 +1372,11 @@
         # Test F.conv1d gradients work with padding='valid'
         x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
         y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
-        F.conv1d(x, y, padding=0).sum().backward()
+        F.conv1d(x, y, padding=0).sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
-        F.conv1d(x, y, padding='valid').sum().backward()
+        F.conv1d(x, y, padding='valid').sum().abs().backward()
         gx_actual, gy_actual = x.grad, y.grad
         self.assertEqual(gx_expect, gx_actual)
         self.assertEqual(gy_expect, gy_actual)
@@ -1510,11 +1510,11 @@
         # Test F.conv2d gradients work with padding='valid'
         x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
         y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
-        F.conv2d(x, y, padding=0).sum().backward()
+        F.conv2d(x, y, padding=0).sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
-        F.conv2d(x, y, padding='valid').sum().backward()
+        F.conv2d(x, y, padding='valid').sum().abs().backward()
         gx_actual, gy_actual = x.grad, y.grad
         self.assertEqual(gx_expect, gx_actual)
         self.assertEqual(gy_expect, gy_actual)
@@ -1526,11 +1526,11 @@
         # Test F.conv3d gradients work with padding='valid'
         x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
         y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
-        F.conv3d(x, y, padding=0).sum().backward()
+        F.conv3d(x, y, padding=0).sum().abs().backward()
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
-        F.conv3d(x, y, padding='valid').sum().backward()
+        F.conv3d(x, y, padding='valid').sum().abs().backward()
         gx_actual, gy_actual = x.grad, y.grad
         self.assertEqual(gx_expect, gx_actual)
         self.assertEqual(gy_expect, gy_actual)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 9233a4e..1c28185 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6455,7 +6455,7 @@
                         with self.assertRaisesRegex(RuntimeError, err_msg):
                             fn(a, b)
                     else:
-                        fn(a, b).backward()
+                        fn(a, b).abs().backward()
 
                     expected_called = 1
                     expected_ga_nz = True
@@ -6809,11 +6809,14 @@
 
     def test_named_tensor_for_complex_views(self):
         names = ["batch", "height", "width", "complex"]
-        z = torch.ones((5, 12, 14, 2), requires_grad=True)
+        z = torch.ones((2, 1, 2, 2), requires_grad=True)
         z_named = z.refine_names(*names)
         z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(*names[:-1])
-        z_complex.sum().backward()
-        self.assertEqual(z.grad, torch.view_as_real(torch.ones_like(z_complex).rename(None)))
+        z_complex.sum().abs().backward()
+        expected = torch.ones_like(z_complex).rename(None)
+        abs_1_1j = abs(1 + 1j)
+        expected.fill_(complex(abs_1_1j / 2, abs_1_1j / 2))
+        self.assertEqual(z.grad, torch.view_as_real(expected))
 
     def test_custom_function_return_view_in_nograd(self):
         class Alias(Function):
@@ -8922,15 +8925,15 @@
 
         # sparse first
         x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
-        (fn.apply(x, sparse_grad1) + fn.apply(x, dense_grad) + fn.apply(x, sparse_grad2)).sum().backward()
+        (fn.apply(x, sparse_grad1) + fn.apply(x, dense_grad) + fn.apply(x, sparse_grad2)).sum().abs().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # dense first
         x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
-        (fn.apply(x, dense_grad) + fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
+        (fn.apply(x, dense_grad) + fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # sparse only
         x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
-        (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
+        (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
         self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
 
     # autograd tests via common_method_invocations don't allow input tensors to
@@ -9637,8 +9640,10 @@
 
         def do_test():
             out_c.copy_(inp_r)
-            out_c.sum().backward()
-            self.assertEqual(inp_r.grad, torch.ones_like(inp_r))
+            out_c_inter = out_c.sum()
+            out_c_inter.abs().backward()
+            with torch.no_grad():
+                self.assertEqual(inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_c_inter).real)
 
         self.assertNotWarn(do_test)
 
@@ -9647,8 +9652,10 @@
             inp_r = torch.randn(3, 2, dtype=torch.double, device=device,
                                 requires_grad=True)
             out = inp_r.to(torch.complex128)
-            out.sum().backward()
-            self.assertEqual(inp_r.grad, torch.ones_like(inp_r))
+            out_inter = out.sum()
+            out_inter.abs().backward()
+            with torch.no_grad():
+                self.assertEqual(inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_inter).real)
 
         self.assertNotWarn(do_test)
 
@@ -9672,6 +9679,17 @@
         with self.assertWarnsRegex(UserWarning, "Warn from backward"):
             b.backward()
 
+    def test_complex_scalar_backward(self, device):
+        a = torch.zeros(1, device=device, requires_grad=True)
+        b = a * 0.5j
+
+        msg = "grad can be implicitly created only for real scalar outputs"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            b.backward()
+
+        with self.assertRaisesRegex(RuntimeError, msg):
+            torch.autograd.grad(b, a)
+
     def test_pow_real_negative_base_complex_exponent(self, device):
         # OpInfo doesn't naturally support input of mixed types, hence this test here.
         base = -torch.ones(2, device=device, dtype=torch.double)
@@ -9819,14 +9837,14 @@
             b = a.conj()
             out = (b**2).sum()
             a.sin_()
-            out.backward()
+            out.abs().backward()
 
             a = torch.tensor([1 + 1j], requires_grad=True).clone()
             b = a.conj()
             out = (b**2).sum()
             # in this case, it is no longer a view it seems
             b.sin_()
-            out.backward()
+            out.abs().backward()
 
     def test_with_out_variant(self):
         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
diff --git a/test/test_linalg.py b/test/test_linalg.py
index b44917a..a81452f 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -2470,18 +2470,18 @@
         A = make_arg((3, 3))
         with self.assertRaisesRegex(RuntimeError, "ill-defined"):
             U, _, Vh = torch.linalg.svd(A, full_matrices=False)
-            (U + Vh).sum().backward()
+            (U + Vh).sum().abs().backward()
 
         A = make_arg((3, 3))
         with self.assertRaisesRegex(RuntimeError, "ill-defined"):
             V = torch.linalg.eig(A).eigenvectors
-            V.sum().backward()
+            V.sum().abs().backward()
 
         A = make_arg((3, 3))
         A = A + A.mH
         with self.assertRaisesRegex(RuntimeError, "ill-defined"):
             Q = torch.linalg.eigh(A).eigenvectors
-            Q.sum().backward()
+            Q.sum().abs().backward()
 
     @skipCUDAIfNoCusolver  # MAGMA backend doesn't work in this case
     @skipCUDAIfRocm
diff --git a/test/test_ops.py b/test/test_ops.py
index c6dd0c3..e2846a0 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1052,7 +1052,10 @@
                 if isinstance(
                     expected_forward, torch.Tensor
                 ) and dtype in op.supported_backward_dtypes(torch.device(device).type):
-                    output_process_fn_grad(expected_forward).sum().backward()
+                    out = output_process_fn_grad(expected_forward).sum()
+                    if out.dtype.is_complex:
+                        out = out.abs()
+                    out.backward()
                     expected_grad = tensor.grad
 
                 # Test eager consistency
@@ -1097,7 +1100,10 @@
                     if expected_grad is not None and (
                         variant not in inplace_ops or op.supports_inplace_autograd
                     ):
-                        output_process_fn_grad(variant_forward).sum().backward()
+                        out = output_process_fn_grad(variant_forward).sum()
+                        if out.dtype.is_complex:
+                            out = out.abs()
+                        out.backward()
                         self.assertEqual(expected_grad, tensor.grad)
 
         _test_consistency_helper(samples, variants)
@@ -1565,8 +1571,8 @@
                     if isinstance(sample.input, torch.Tensor)
                     else sample.input[0]
                 )
-                expected_forward.sum().backward(retain_graph=True)
-                forward_with_mathview.sum().backward(retain_graph=True)
+                expected_forward.sum().abs().backward(retain_graph=True)
+                forward_with_mathview.sum().abs().backward(retain_graph=True)
                 if tensor.grad is not None:
                     cloned1_tensor = (
                         cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
diff --git a/test/test_optim.py b/test/test_optim.py
index b2ddad4..2b0e508 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -430,8 +430,8 @@
             optim1.zero_grad()
             optim2.zero_grad()
             a2 = torch.complex(a1_real, a1_imag)
-            f(a1).backward()
-            f(a2).backward()
+            f(a1).abs().backward()
+            f(a2).abs().backward()
 
             self.assertEqual(a1.grad.real, a1_real.grad)
             self.assertEqual(a1.grad.imag, a1_imag.grad)
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 9327d59..731df68 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -3885,8 +3885,10 @@
             self.assertEqual(a.sum(), a._values().sum())
             if dtype.is_floating_point or dtype.is_complex:
                 a.requires_grad_(True)
-                a.sum().backward()
-                self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device))
+                a_inter = a.sum()
+                a_inter.abs().backward()
+                with torch.no_grad():
+                    self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device) * torch.sgn(a_inter))
         for shape in [(10, 5), (10, 10)]:
             run_test(shape, 0)
             run_test(shape, max(shape))
@@ -4558,8 +4560,8 @@
 
             if op.name == 'sum':
                 count += 1
-                r.backward()
-                self.assertEqual(t_inp.grad, torch.ones(t_inp.shape, dtype=dtype, device=device))
+                r.abs().backward()
+                self.assertEqual(t_inp.grad, torch.ones(t_inp.shape, dtype=dtype, device=device) * torch.sgn(r))
             else:
                 self.skipTest('NOT IMPL')
 
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 84fec20..c71e36c 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -86,6 +86,10 @@
             if out.requires_grad:
                 if out.numel() != 1:
                     raise RuntimeError("grad can be implicitly created only for scalar outputs")
+                if not out.dtype.is_floating_point:
+                    msg = ("grad can be implicitly created only for real scalar outputs"
+                           f" but got {out.dtype}")
+                    raise RuntimeError(msg)
                 new_grads.append(torch.ones_like(out, memory_format=torch.preserve_format))
             else:
                 new_grads.append(None)
diff --git a/torch/csrc/autograd/autograd.cpp b/torch/csrc/autograd/autograd.cpp
index 8381032..b81e5be 100644
--- a/torch/csrc/autograd/autograd.cpp
+++ b/torch/csrc/autograd/autograd.cpp
@@ -37,6 +37,10 @@
         TORCH_CHECK(
             output.numel() == 1,
             "grad can be implicitly created only for scalar outputs");
+        TORCH_CHECK(
+            c10::isFloatingType(output.scalar_type()),
+            "grad can be computed only for real scalar outputs but got ",
+            output.scalar_type());
         new_grads.emplace_back(
             at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
       }
@@ -57,6 +61,10 @@
           TORCH_CHECK(
               output.numel() == 1,
               "grad can be implicitly created only for scalar outputs");
+          TORCH_CHECK(
+              c10::isFloatingType(output.scalar_type()),
+              "grad can be computed only for real scalar outputs but got ",
+              output.scalar_type());
           new_grads.emplace_back(
               at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
         }
diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py
index 30e3207..25b7bd8 100644
--- a/torch/testing/_internal/common_jit.py
+++ b/torch/testing/_internal/common_jit.py
@@ -51,7 +51,7 @@
     def allSum(vs):
         if isinstance(vs, torch.Tensor):
             vs = (vs,)
-        return sum((i + 1) * v.sum()
+        return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
                    for i, v in enumerate(vs)
                    if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))
 
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 14ad5a4..c60bd4e 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -6033,6 +6033,9 @@
         cpu_input = self._get_input()
         type_map = {torch.double: torch.float}
         cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
+
+        is_any_input_complex = any(map(lambda t: isinstance(t, torch.Tensor) and t.dtype.is_complex, cpu_input_tuple))
+
         gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
 
         cpu_module = self.constructor(*self.constructor_args)
@@ -6093,12 +6096,19 @@
             # torch.autograd.grad doesn't complain that some inputs
             # are unreachable (which can happen if you differentiate
             # only on the gradient.
+            if is_any_input_complex:
+                outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
+                outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
+            else:
+                outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
+                outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)
+
             cpu_gg = torch.autograd.grad(
-                cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs),
+                outputs_cpu,
                 cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
                 retain_graph=True)
             gpu_gg = torch.autograd.grad(
-                gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs),
+                outputs_gpu,
                 gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
                 retain_graph=True)
             test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)