Fix complex-to-real casting warning in _to_copy backward
Fixes #75781
For a real->complex cast, the gradient with respect to the real input is
just the real part of the incoming complex gradient, so discarding the
imaginary component is expected and should not raise the complex-to-real
casting warning.
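A short sketch of why, using PyTorch's convention that the gradient stored
for a complex tensor z = z_r + i*z_i under a real loss L is
g = dL/dz_r + i*dL/dz_i:

    Forward:  z = x + 0i                  (z_r = x, z_i = 0, x real)
    Backward: dL/dx = (dL/dz_r)*(dz_r/dx) + (dL/dz_i)*(dz_i/dx)
                    = dL/dz_r
                    = Re(g)

The imaginary part of g never contributes to the real input's gradient,
which is why silently taking the real part is the correct backward here.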
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75805
Approved by: https://github.com/albanD
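A minimal standalone reproduction, sketched with the libtorch C++ API
(assumes a build that includes this patch; the file and variable names are
illustrative only):

    // repro.cpp -- round-trip a real tensor through a complex cast.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto inp = torch::randn({3, 2},
          torch::dtype(torch::kDouble).requires_grad(true));
      // Forward: real -> complex cast.
      auto out = inp.to(torch::kComplexDouble);
      // Backward with an explicit complex gradient whose imaginary part is
      // nonzero; only the real part should reach the real-valued input.
      out.backward(torch::full_like(out, c10::complex<double>(2.0, 3.0)));
      // Prints a double tensor filled with 2.0. Before this patch, the
      // complex->real cast in the backward pass warned:
      // "Casting complex values to real discards the imaginary part".
      std::cout << inp.grad() << "\n";
    }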
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 408c71a..ce34469 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7828,6 +7828,16 @@
self.assertNotWarn(do_test)
+ def test_to_r_to_c(self, device):
+ def do_test():
+ inp_r = torch.randn(3, 2, dtype=torch.double, device=device,
+ requires_grad=True)
+ out = inp_r.to(torch.complex128)
+ out.sum().backward()
+ self.assertEqual(inp_r.grad, torch.ones_like(inp_r))
+
+ self.assertNotWarn(do_test)
+
def test_non_differentiable_ops(self, device):
# Just make sure the op doesn't raise an error
# and resulting tensor has requires_grad=False.
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 23b6137..168ff77 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -402,7 +402,7 @@
result: auto_linear
- name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
- self: grad.to(self.options(), /*non_blocking*/false, /*copy*/false)
+ self: _to_copy_backward(grad, self.options())
result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format)
# The condition is: if dtype is not nullopt, then isDifferentiableType(*dtype)
# (If dtype IS nullopt, we rely on the regular check that any input requires grad).
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index a2d123a..3519e66 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -5398,6 +5398,16 @@
}
+Tensor _to_copy_backward(const Tensor &grad_, const c10::TensorOptions &self_options) {
+ // Handle R->C copies without raising a warning
+ const auto self_type = self_options.dtype().toScalarType();
+ auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grad_);
+ if (!c10::isComplexType(self_type) && grad->is_complex()) {
+ grad = c10::MaybeOwned<at::Tensor>::owned(at::real(grad_));
+ }
+
+ return grad->to(self_options, /*non_blocking=*/false, /*copy=*/false);
+}
} // namespace details
} // namespace generated
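For context on the c10::MaybeOwned idiom used above: it holds either a
borrowed reference (no refcount traffic, the common all-real case) or an
owned tensor (only when a real view must be materialized), behind one
pointer-like interface. A minimal sketch of the same pattern; the function
name here is hypothetical, while c10::MaybeOwned, at::real, and
c10::isComplexType are the actual APIs:

    #include <ATen/ATen.h>
    #include <c10/util/MaybeOwned.h>

    // Cast grad to `target`, stripping the imaginary part first when the
    // target dtype is real -- the same borrow-or-own shape as
    // _to_copy_backward.
    at::Tensor cast_grad(const at::Tensor& grad_, at::ScalarType target) {
      // Common case: borrow the incoming tensor (no refcount bump).
      auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grad_);
      if (!c10::isComplexType(target) && grad->is_complex()) {
        // at::real returns a real view of a complex tensor; store it as
        // owned so MaybeOwned keeps it alive for the cast below.
        grad = c10::MaybeOwned<at::Tensor>::owned(at::real(grad_));
      }
      return grad->to(target);
    }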
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 5938e3f..bb93af1 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -492,6 +492,7 @@
const Tensor& result
);
+Tensor _to_copy_backward(const Tensor &grad, const c10::TensorOptions &self_options);
} // namespace details
} // namespace generated