Fixes cat backward formula to return correct gradient values for the R -> C case (#51681)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51681

Fixes https://github.com/pytorch/pytorch/issues/51627
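
Before this change, `cat_tensors_backward` returned the same slice of the upstream gradient for every input regardless of its dtype, so real inputs of a complex-valued `cat` (the R -> C case) received incorrect gradients, as reported in the linked issue. The backward formula now also receives each input's `ScalarType` (via the new `to_args_scalartypes` helper) and, for real inputs whose upstream gradient is complex, returns only the real part of the gradient slice.

A minimal sketch of the behavior being fixed, mirroring the new `test_cat_r_to_c` test; the explicit `backward` call and dtype checks are illustrative only and not part of the patch:

```python
import torch

inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)   # real input
inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)   # complex input

# Concatenating a real and a complex tensor yields a complex result,
# so the upstream gradient is complex as well.
out = torch.cat((inp_r, inp_c), dim=-1)
out.backward(torch.ones_like(out) * (1 + 1j))

print(inp_r.grad.dtype)  # torch.float64 -- real part of the complex grad slice
print(inp_c.grad.dtype)  # torch.complex128

# gradcheck over both argument orders, as in the added test
def fn(x1, x2):
    return torch.cat((x1, x2), dim=-1)

torch.autograd.gradcheck(fn, [inp_r, inp_c])
torch.autograd.gradcheck(fn, [inp_c, inp_r])
```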

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D26238748

Pulled By: anjali411

fbshipit-source-id: 1dc47f8ddddbf3f2c176f21e5dcee917f84f4c93
diff --git a/test/test_autograd.py b/test/test_autograd.py
index bf125c6..592164d 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7807,6 +7807,18 @@
             tb_str = "\n".join(traceback.format_tb(tb))
             self.assertTrue('raise ValueError("something")' in tb_str)
 
+    # TODO(@anjali411): add an OpInfo based test for torch.cat
+    # Issue: https://github.com/pytorch/pytorch/issues/51627
+    def test_cat_r_to_c(self):
+        inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
+        inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
+
+        def fn(x1, x2):
+            return torch.cat((x1, x2), dim=-1)
+
+        torch.autograd.gradcheck(fn, [inp_r, inp_c])
+        torch.autograd.gradcheck(fn, [inp_c, inp_r])
+
 for test in method_tests():
     add_test(*test)
 
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 60e314a..21eb1e2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -283,7 +283,7 @@
   mat2: at::_bmm(self.transpose(1, 2), grad, deterministic)
 
 - name: cat(Tensor[] tensors, int dim=0) -> Tensor
-  tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)
+  tensors: cat_tensors_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors), dim)
 
 - name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
   self: zeros_like(grad)
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index d5c742b..c74232f 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -279,6 +279,11 @@
             'suffix': '_args_sizes',
             'type': 'std::vector<std::vector<int64_t>>',
         }),
+        # replace to_args_scalartypes(self) with self_args_scalartypes
+        (r'to_args_scalartypes\({}\)', {
+            'suffix': '_args_scalartypes',
+            'type': 'std::vector<ScalarType>',
+        }),
         # replace TensorGeometry(self) with self_geometry
         (r'TensorGeometry\({}\)', {
             'suffix': '_geometry',
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 21e9b44..552f18b 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -553,23 +553,36 @@
   return self;
 }
 
-std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
+std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes, int64_t dim) {
   std::vector<Tensor> grad_inputs(sizes.size());
   if (!grad.defined()) {
     return grad_inputs;
   }
   dim = at::legacy_cat_wrap_dim(dim, sizes);
   int64_t accumulate = 0;
+
+  Tensor grad_;
+  bool grad_is_complex = grad.is_complex();
+  if (grad_is_complex) {
+    grad_ = at::real(grad);
+  }
   for (size_t i = 0; i < sizes.size(); ++i) {
+    Tensor grad_val;
+    if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
+      // R -> C
+      grad_val = grad_;
+    } else {
+      grad_val = grad;
+    }
     auto& shape = sizes[i];
     // If input was empty tensor, gradInput should be empty tensor.
     if (shape == std::vector<int64_t>({0})) {
-      grad_inputs[i] = at::zeros({0}, grad.options());
+      grad_inputs[i] = at::zeros({0}, grad_val.options());
       continue;
     }
     auto size = shape[dim];
     accumulate += size;
-    grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
+    grad_inputs[i] = grad_val.narrow(dim, accumulate - size, size);
   }
   return grad_inputs;
 }
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 4cce8cf..997c75e 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -76,7 +76,7 @@
 at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
 at::Tensor unsqueeze_to(const at::Tensor & self, at::IntArrayRef sizes);
 at::Tensor unsqueeze_to(const at::Tensor & self, int64_t dim, at::IntArrayRef sizes);
-std::vector<at::Tensor> cat_tensors_backward(const at::Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim);
+std::vector<at::Tensor> cat_tensors_backward(const at::Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes, int64_t dim);
 at::Tensor clamp_backward(const at::Tensor & grad, const at::Tensor &self, const optional<at::Scalar> & min, const optional<at::Scalar> & max);
 at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name);
 at::Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha);
diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h
index 85b83f2..2221533 100644
--- a/torch/csrc/autograd/VariableTypeUtils.h
+++ b/torch/csrc/autograd/VariableTypeUtils.h
@@ -309,4 +309,11 @@
   return args_sizes;
 }
 
+inline std::vector<ScalarType> to_args_scalartypes(TensorList tensors) {
+  std::vector<ScalarType> args_scalartypes(tensors.size());
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    args_scalartypes[i] = tensors[i].scalar_type();
+  }
+  return args_scalartypes;
+}
 }} // namespace torch::autograd