[autograd] disable backward/grad for complex scalar output (#92753)
Fixes https://github.com/pytorch/pytorch/issues/92750
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92753
Approved by: https://github.com/ezyang
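
In short: calling .backward() with no gradient argument (or torch.autograd.grad with no
grad_outputs) on a complex scalar output now raises a RuntimeError instead of implicitly
seeding the backward pass with ones; the test callsites in this diff are updated to reduce
to a real scalar via .abs() first. A minimal sketch of the new behavior (variable names and
shapes are illustrative, mirroring the added test_complex_scalar_backward):

    import torch

    x = torch.randn(3, requires_grad=True)
    out = (x * 0.5j).sum()  # complex scalar output

    try:
        out.backward()  # previously seeded the grad with ones; now rejected
    except RuntimeError as e:
        # "grad can be implicitly created only for real scalar outputs but got torch.complex64"
        print(e)

    # Reduce to a real scalar first, as the updated tests do:
    (x * 0.5j).sum().abs().backward()
    print(x.grad)

Passing an explicit gradient= / grad_outputs= should be unaffected, since the new checks sit
only in the implicit ones_like paths changed below in torch/autograd/__init__.py and
torch/csrc/autograd/autograd.cpp.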
diff --git a/test/autograd/test_complex.py b/test/autograd/test_complex.py
index c8796a4..5162e03 100644
--- a/test/autograd/test_complex.py
+++ b/test/autograd/test_complex.py
@@ -15,11 +15,11 @@
x1 = torch.view_as_complex(x0)
x2 = torch.view_as_real(x1)
x2.mul_(2)
- x2.sum().backward()
+ x2.sum().abs().backward()
y0 = y.clone()
y0.mul_(2)
- y0.sum().backward()
+ y0.sum().abs().backward()
self.assertEqual(x.grad, y.grad)
@@ -35,11 +35,11 @@
x0 = fn(x)
x0.mul_(2)
- x0.sum().backward()
+ x0.sum().abs().backward()
y0 = fn(y)
y1 = y0.mul(2)
- y1.sum().backward()
+ y1.sum().abs().backward()
self.assertEqual(x.grad, y.grad)
@@ -55,11 +55,11 @@
x0 = fn(x)
x0.mul_(2)
- x0.sum().backward()
+ x0.sum().abs().backward()
y0 = fn(y)
y1 = y0.mul(2)
- y1.sum().backward()
+ y1.sum().abs().backward()
self.assertEqual(x.grad, y.grad)
diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp
index 78d629f..2c4352e 100644
--- a/test/cpp/api/tensor.cpp
+++ b/test/cpp/api/tensor.cpp
@@ -1099,6 +1099,13 @@
y.backward(), "grad can be implicitly created only for scalar outputs");
}
+TEST(TensorTest, BackwardComplexScalarOutput) {
+ auto x = torch::randn({5, 5}, torch::requires_grad());
+ auto y = (x * c10::Scalar(c10::complex<float>(0, 0.5))).sum();
+ ASSERT_THROWS_WITH(
+ y.backward(), "grad can be computed only for real scalar outputs");
+}
+
TEST(TensorTest, IsLeaf) {
auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
auto y = x * x;
diff --git a/test/cpp_api_parity/module_impl_check.py b/test/cpp_api_parity/module_impl_check.py
index 6e44809..bbfad91 100644
--- a/test/cpp_api_parity/module_impl_check.py
+++ b/test/cpp_api_parity/module_impl_check.py
@@ -65,7 +65,11 @@
write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
// Backward pass
- cpp_output.sum().backward();
+ if (cpp_output.is_complex()) {
+ cpp_output.sum().abs().backward();
+ } else {
+ cpp_output.sum().backward();
+ }
// Put all gradients into a c10::Dict, save it into a file to be compared in Python later
c10::Dict<std::string, torch::Tensor> grad_dict;
@@ -109,7 +113,10 @@
script_module = torch.jit.trace(module, torch.tensor(0))
# Backward pass
- python_output.sum().backward()
+ if python_output.dtype.is_complex:
+ python_output.sum().abs().backward()
+ else:
+ python_output.sum().backward()
# Put all gradients into a dict, to be compared later
python_grad_dict = {}
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 4009e37..508d0d1 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2549,10 +2549,10 @@
flat_out, _ = pytree.tree_flatten(out)
sm = 0
for i in flat_out:
- sm += i.sum()
+ sm += i.sum().abs()
sm.backward()
else:
- out.sum().backward()
+ out.sum().abs().backward()
def reset_grads():
def f(x):
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 3ca8839..c75ef62 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -369,9 +369,10 @@
assert not x.is_conj()
y = x.conj()
assert y.is_conj()
- return y
+ return y.abs()
res = grad(foo)(x)
- self.assertEqual(res, torch.ones_like(res))
+ with torch.no_grad():
+ self.assertEqual(res, torch.ones_like(res) * torch.sgn(x))
def test_composed_with_autograd(self, device):
x = torch.randn([], requires_grad=True, device=device)
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index c1dec9a..0e4d807 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -394,7 +394,7 @@
tol1('masked.cumprod',
{torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol1('svd_lowrank',
- {torch.float32: tol(atol=3e-05, rtol=3e-05)}, device_type='cuda'),
+ {torch.float32: tol(atol=3e-05, rtol=3e-04)}, device_type='cuda'),
tol1('linalg.tensorsolve',
{torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type='cuda'),
))
@@ -430,10 +430,15 @@
if sample.output_process_fn_grad is not None:
result = sample.output_process_fn_grad(result)
+ def abs_if_complex(t):
+ if t.dtype.is_complex:
+ return t.abs()
+ return t
+
# Reduce into single value for grad
if isinstance(result, torch.Tensor):
- return result.sum()
- result = sum([res.sum() for res in result])
+ return abs_if_complex(result.sum())
+ result = sum([abs_if_complex(res.sum()) for res in result])
return result
result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py
index 5413513..f35a777 100644
--- a/test/nn/test_convolution.py
+++ b/test/nn/test_convolution.py
@@ -1270,24 +1270,24 @@
# Symmetric padding
z = F.conv1d(x, y, padding=3, dilation=2)
- z.sum().backward()
+ z.sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
z = F.conv1d(x, y, padding='same', dilation=2)
- z.sum().backward()
+ z.sum().abs().backward()
self.assertEqual(gx_expect, x.grad)
self.assertEqual(gy_expect, y.grad)
x.grad, y.grad = None, None
# Asymmetric padding
z = F.conv1d(x, y, padding=2)[..., 1:]
- z.sum().backward()
+ z.sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
z = F.conv1d(x, y, padding='same')
- z.sum().backward()
+ z.sum().abs().backward()
self.assertEqual(gx_expect, x.grad)
self.assertEqual(gy_expect, y.grad)
@@ -1299,12 +1299,12 @@
# Symmetric padding
z = F.conv2d(x, y, padding=(3, 4), dilation=2)
- z.sum().backward()
+ z.sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
z = F.conv2d(x, y, padding='same', dilation=2)
- z.sum().backward()
+ z.sum().abs().backward()
self.assertEqual(gx_expect, x.grad)
self.assertEqual(gy_expect, y.grad)
x.grad, y.grad = None, None
@@ -1312,12 +1312,12 @@
# Asymmetric padding
y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
- z.sum().backward()
+ z.sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
z = F.conv2d(x, y, padding='same')
- z.sum().backward()
+ z.sum().abs().backward()
self.assertEqual(gx_expect, x.grad)
self.assertEqual(gy_expect, y.grad)
@@ -1331,12 +1331,12 @@
# Symmetric padding
z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
- z.sum().backward()
+ z.sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
z = F.conv3d(x, y, padding='same', dilation=2)
- z.sum().backward()
+ z.sum().abs().backward()
self.assertEqual(gx_expect, x.grad)
self.assertEqual(gy_expect, y.grad)
x.grad, y.grad = None, None
@@ -1351,12 +1351,12 @@
# Asymmetric padding
y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
- z.sum().backward()
+ z.sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
z = F.conv3d(x, y, padding='same')
- z.sum().backward()
+ z.sum().abs().backward()
self.assertEqual(gx_expect, x.grad)
self.assertEqual(gy_expect, y.grad)
@@ -1372,11 +1372,11 @@
# Test F.conv1d gradients work with padding='valid'
x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
- F.conv1d(x, y, padding=0).sum().backward()
+ F.conv1d(x, y, padding=0).sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
- F.conv1d(x, y, padding='valid').sum().backward()
+ F.conv1d(x, y, padding='valid').sum().abs().backward()
gx_actual, gy_actual = x.grad, y.grad
self.assertEqual(gx_expect, gx_actual)
self.assertEqual(gy_expect, gy_actual)
@@ -1510,11 +1510,11 @@
# Test F.conv2d gradients work with padding='valid'
x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
- F.conv2d(x, y, padding=0).sum().backward()
+ F.conv2d(x, y, padding=0).sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
- F.conv2d(x, y, padding='valid').sum().backward()
+ F.conv2d(x, y, padding='valid').sum().abs().backward()
gx_actual, gy_actual = x.grad, y.grad
self.assertEqual(gx_expect, gx_actual)
self.assertEqual(gy_expect, gy_actual)
@@ -1526,11 +1526,11 @@
# Test F.conv3d gradients work with padding='valid'
x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
- F.conv3d(x, y, padding=0).sum().backward()
+ F.conv3d(x, y, padding=0).sum().abs().backward()
gx_expect, gy_expect = x.grad, y.grad
x.grad, y.grad = None, None
- F.conv3d(x, y, padding='valid').sum().backward()
+ F.conv3d(x, y, padding='valid').sum().abs().backward()
gx_actual, gy_actual = x.grad, y.grad
self.assertEqual(gx_expect, gx_actual)
self.assertEqual(gy_expect, gy_actual)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 9233a4e..1c28185 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6455,7 +6455,7 @@
with self.assertRaisesRegex(RuntimeError, err_msg):
fn(a, b)
else:
- fn(a, b).backward()
+ fn(a, b).abs().backward()
expected_called = 1
expected_ga_nz = True
@@ -6809,11 +6809,14 @@
def test_named_tensor_for_complex_views(self):
names = ["batch", "height", "width", "complex"]
- z = torch.ones((5, 12, 14, 2), requires_grad=True)
+ z = torch.ones((2, 1, 2, 2), requires_grad=True)
z_named = z.refine_names(*names)
z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(*names[:-1])
- z_complex.sum().backward()
- self.assertEqual(z.grad, torch.view_as_real(torch.ones_like(z_complex).rename(None)))
+ z_complex.sum().abs().backward()
+ expected = torch.ones_like(z_complex).rename(None)
+ abs_1_1j = abs(1 + 1j)
+ expected.fill_(complex(abs_1_1j / 2, abs_1_1j / 2))
+ self.assertEqual(z.grad, torch.view_as_real(expected))
def test_custom_function_return_view_in_nograd(self):
class Alias(Function):
@@ -8922,15 +8925,15 @@
# sparse first
x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
- (fn.apply(x, sparse_grad1) + fn.apply(x, dense_grad) + fn.apply(x, sparse_grad2)).sum().backward()
+ (fn.apply(x, sparse_grad1) + fn.apply(x, dense_grad) + fn.apply(x, sparse_grad2)).sum().abs().backward()
self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
# dense first
x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
- (fn.apply(x, dense_grad) + fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
+ (fn.apply(x, dense_grad) + fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
# sparse only
x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
- (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().backward()
+ (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
# autograd tests via common_method_invocations don't allow input tensors to
@@ -9637,8 +9640,10 @@
def do_test():
out_c.copy_(inp_r)
- out_c.sum().backward()
- self.assertEqual(inp_r.grad, torch.ones_like(inp_r))
+ out_c_inter = out_c.sum()
+ out_c_inter.abs().backward()
+ with torch.no_grad():
+ self.assertEqual(inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_c_inter).real)
self.assertNotWarn(do_test)
@@ -9647,8 +9652,10 @@
inp_r = torch.randn(3, 2, dtype=torch.double, device=device,
requires_grad=True)
out = inp_r.to(torch.complex128)
- out.sum().backward()
- self.assertEqual(inp_r.grad, torch.ones_like(inp_r))
+ out_inter = out.sum()
+ out_inter.abs().backward()
+ with torch.no_grad():
+ self.assertEqual(inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_inter).real)
self.assertNotWarn(do_test)
@@ -9672,6 +9679,17 @@
with self.assertWarnsRegex(UserWarning, "Warn from backward"):
b.backward()
+ def test_complex_scalar_backward(self, device):
+ a = torch.zeros(1, device=device, requires_grad=True)
+ b = a * 0.5j
+
+ msg = "grad can be implicitly created only for real scalar outputs"
+ with self.assertRaisesRegex(RuntimeError, msg):
+ b.backward()
+
+ with self.assertRaisesRegex(RuntimeError, msg):
+ torch.autograd.grad(b, a)
+
def test_pow_real_negative_base_complex_exponent(self, device):
# OpInfo doesn't naturally support input of mixed types, hence this test here.
base = -torch.ones(2, device=device, dtype=torch.double)
@@ -9819,14 +9837,14 @@
b = a.conj()
out = (b**2).sum()
a.sin_()
- out.backward()
+ out.abs().backward()
a = torch.tensor([1 + 1j], requires_grad=True).clone()
b = a.conj()
out = (b**2).sum()
# in this case, it is no longer a view it seems
b.sin_()
- out.backward()
+ out.abs().backward()
def test_with_out_variant(self):
with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
diff --git a/test/test_linalg.py b/test/test_linalg.py
index b44917a..a81452f 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -2470,18 +2470,18 @@
A = make_arg((3, 3))
with self.assertRaisesRegex(RuntimeError, "ill-defined"):
U, _, Vh = torch.linalg.svd(A, full_matrices=False)
- (U + Vh).sum().backward()
+ (U + Vh).sum().abs().backward()
A = make_arg((3, 3))
with self.assertRaisesRegex(RuntimeError, "ill-defined"):
V = torch.linalg.eig(A).eigenvectors
- V.sum().backward()
+ V.sum().abs().backward()
A = make_arg((3, 3))
A = A + A.mH
with self.assertRaisesRegex(RuntimeError, "ill-defined"):
Q = torch.linalg.eigh(A).eigenvectors
- Q.sum().backward()
+ Q.sum().abs().backward()
@skipCUDAIfNoCusolver # MAGMA backend doesn't work in this case
@skipCUDAIfRocm
diff --git a/test/test_ops.py b/test/test_ops.py
index c6dd0c3..e2846a0 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1052,7 +1052,10 @@
if isinstance(
expected_forward, torch.Tensor
) and dtype in op.supported_backward_dtypes(torch.device(device).type):
- output_process_fn_grad(expected_forward).sum().backward()
+ out = output_process_fn_grad(expected_forward).sum()
+ if out.dtype.is_complex:
+ out = out.abs()
+ out.backward()
expected_grad = tensor.grad
# Test eager consistency
@@ -1097,7 +1100,10 @@
if expected_grad is not None and (
variant not in inplace_ops or op.supports_inplace_autograd
):
- output_process_fn_grad(variant_forward).sum().backward()
+ out = output_process_fn_grad(variant_forward).sum()
+ if out.dtype.is_complex:
+ out = out.abs()
+ out.backward()
self.assertEqual(expected_grad, tensor.grad)
_test_consistency_helper(samples, variants)
@@ -1565,8 +1571,8 @@
if isinstance(sample.input, torch.Tensor)
else sample.input[0]
)
- expected_forward.sum().backward(retain_graph=True)
- forward_with_mathview.sum().backward(retain_graph=True)
+ expected_forward.sum().abs().backward(retain_graph=True)
+ forward_with_mathview.sum().abs().backward(retain_graph=True)
if tensor.grad is not None:
cloned1_tensor = (
cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
diff --git a/test/test_optim.py b/test/test_optim.py
index b2ddad4..2b0e508 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -430,8 +430,8 @@
optim1.zero_grad()
optim2.zero_grad()
a2 = torch.complex(a1_real, a1_imag)
- f(a1).backward()
- f(a2).backward()
+ f(a1).abs().backward()
+ f(a2).abs().backward()
self.assertEqual(a1.grad.real, a1_real.grad)
self.assertEqual(a1.grad.imag, a1_imag.grad)
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 9327d59..731df68 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -3885,8 +3885,10 @@
self.assertEqual(a.sum(), a._values().sum())
if dtype.is_floating_point or dtype.is_complex:
a.requires_grad_(True)
- a.sum().backward()
- self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device))
+ a_inter = a.sum()
+ a_inter.abs().backward()
+ with torch.no_grad():
+ self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device) * torch.sgn(a_inter))
for shape in [(10, 5), (10, 10)]:
run_test(shape, 0)
run_test(shape, max(shape))
@@ -4558,8 +4560,8 @@
if op.name == 'sum':
count += 1
- r.backward()
- self.assertEqual(t_inp.grad, torch.ones(t_inp.shape, dtype=dtype, device=device))
+ r.abs().backward()
+ self.assertEqual(t_inp.grad, torch.ones(t_inp.shape, dtype=dtype, device=device) * torch.sgn(r))
else:
self.skipTest('NOT IMPL')
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 84fec20..c71e36c 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -86,6 +86,10 @@
if out.requires_grad:
if out.numel() != 1:
raise RuntimeError("grad can be implicitly created only for scalar outputs")
+ if not out.dtype.is_floating_point:
+ msg = ("grad can be implicitly created only for real scalar outputs"
+ f" but got {out.dtype}")
+ raise RuntimeError(msg)
new_grads.append(torch.ones_like(out, memory_format=torch.preserve_format))
else:
new_grads.append(None)
diff --git a/torch/csrc/autograd/autograd.cpp b/torch/csrc/autograd/autograd.cpp
index 8381032..b81e5be 100644
--- a/torch/csrc/autograd/autograd.cpp
+++ b/torch/csrc/autograd/autograd.cpp
@@ -37,6 +37,10 @@
TORCH_CHECK(
output.numel() == 1,
"grad can be implicitly created only for scalar outputs");
+ TORCH_CHECK(
+ c10::isFloatingType(output.scalar_type()),
+ "grad can be computed only for real scalar outputs but got ",
+ output.scalar_type());
new_grads.emplace_back(
at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
@@ -57,6 +61,10 @@
TORCH_CHECK(
output.numel() == 1,
"grad can be implicitly created only for scalar outputs");
+ TORCH_CHECK(
+ c10::isFloatingType(output.scalar_type()),
+ "grad can be computed only for real scalar outputs but got ",
+ output.scalar_type());
new_grads.emplace_back(
at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py
index 30e3207..25b7bd8 100644
--- a/torch/testing/_internal/common_jit.py
+++ b/torch/testing/_internal/common_jit.py
@@ -51,7 +51,7 @@
def allSum(vs):
if isinstance(vs, torch.Tensor):
vs = (vs,)
- return sum((i + 1) * v.sum()
+ return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
for i, v in enumerate(vs)
if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 14ad5a4..c60bd4e 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -6033,6 +6033,9 @@
cpu_input = self._get_input()
type_map = {torch.double: torch.float}
cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
+
+ is_any_input_complex = any(map(lambda t: isinstance(t, torch.Tensor) and t.dtype.is_complex, cpu_input_tuple))
+
gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
cpu_module = self.constructor(*self.constructor_args)
@@ -6093,12 +6096,19 @@
# torch.autograd.grad doesn't complain that some inputs
# are unreachable (which can happen if you differentiate
# only on the gradient.
+ if is_any_input_complex:
+ outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
+ outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
+ else:
+ outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
+ outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)
+
cpu_gg = torch.autograd.grad(
- cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs),
+ outputs_cpu,
cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
retain_graph=True)
gpu_gg = torch.autograd.grad(
- gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs),
+ outputs_gpu,
gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
retain_graph=True)
test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)