Fixes cat backward formula to return correct gradient values for R -> C case (#51681)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51681
Fixes https://github.com/pytorch/pytorch/issues/51627
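
Before this change, `cat_tensors_backward` handed every input the same narrowed slice of the upstream gradient, so a real (R) input concatenated with complex (C) inputs received a complex gradient slice and ended up with incorrect gradient values. With this change, the slice routed to a real input is taken from the real part of the upstream gradient. A minimal sketch of the R -> C scenario this fixes (illustration only, not part of this PR):

```python
import torch

a = torch.randn(3, 2, dtype=torch.double, requires_grad=True)   # real input (R)
b = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)    # complex input (C)

out = torch.cat((a, b), dim=-1)      # result is complex, so this is the R -> C case for `a`
out.backward(torch.ones_like(out))   # upstream gradient is complex

# With this fix, the gradient routed to the real input is the real part of the
# upstream gradient: a.grad is real-valued while b.grad stays complex.
print(a.grad.dtype, b.grad.dtype)    # expected: torch.float64 torch.complex128
```

To make this possible, the backward now also receives the per-input scalar types: `derivatives.yaml` passes `to_args_scalartypes(tensors)` into `cat_tensors_backward`, and the new helper in `VariableTypeUtils.h` collects each input's `scalar_type()`, so the backward can tell which inputs were real.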
Test Plan: Imported from OSS
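
The added `test_cat_r_to_c` unit test runs `gradcheck` on `torch.cat` with mixed real/complex inputs in both argument orders; a standalone sketch of the same check (assumes a build that includes this patch):

```python
import torch
from torch.autograd import gradcheck

inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)

def fn(x1, x2):
    return torch.cat((x1, x2), dim=-1)

# Both orders exercise the R -> C path; both are expected to pass with this fix.
assert gradcheck(fn, (inp_r, inp_c))
assert gradcheck(fn, (inp_c, inp_r))
```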
Reviewed By: gchanan
Differential Revision: D26238748
Pulled By: anjali411
fbshipit-source-id: 1dc47f8ddddbf3f2c176f21e5dcee917f84f4c93
diff --git a/test/test_autograd.py b/test/test_autograd.py
index bf125c6..592164d 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7807,6 +7807,18 @@
tb_str = "\n".join(traceback.format_tb(tb))
self.assertTrue('raise ValueError("something")' in tb_str)
+ # TODO(@anjali411): add an OpInfo based test for torch.cat
+ # Issue: https://github.com/pytorch/pytorch/issues/51627
+ def test_cat_r_to_c(self):
+ inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
+ inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
+
+ def fn(x1, x2):
+ return torch.cat((x1, x2), dim=-1)
+
+ torch.autograd.gradcheck(fn, [inp_r, inp_c])
+ torch.autograd.gradcheck(fn, [inp_c, inp_r])
+
for test in method_tests():
add_test(*test)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 60e314a..21eb1e2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -283,7 +283,7 @@
mat2: at::_bmm(self.transpose(1, 2), grad, deterministic)
- name: cat(Tensor[] tensors, int dim=0) -> Tensor
- tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)
+ tensors: cat_tensors_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors), dim)
- name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index d5c742b..c74232f 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -279,6 +279,11 @@
'suffix': '_args_sizes',
'type': 'std::vector<std::vector<int64_t>>',
}),
+ # replace to_args_scalartypes(self) with self_args_scalartypes
+ (r'to_args_scalartypes\({}\)', {
+ 'suffix': '_args_scalartypes',
+ 'type': 'std::vector<ScalarType>',
+ }),
# replace TensorGeometry(self) with self_geometry
(r'TensorGeometry\({}\)', {
'suffix': '_geometry',
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 21e9b44..552f18b 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -553,23 +553,36 @@
return self;
}
-std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
+std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes, int64_t dim) {
std::vector<Tensor> grad_inputs(sizes.size());
if (!grad.defined()) {
return grad_inputs;
}
dim = at::legacy_cat_wrap_dim(dim, sizes);
int64_t accumulate = 0;
+
+ Tensor grad_;
+ bool grad_is_complex = grad.is_complex();
+ if (grad_is_complex) {
+ grad_ = at::real(grad);
+ }
for (size_t i = 0; i < sizes.size(); ++i) {
+ Tensor grad_val;
+ if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
+ // R -> C
+ grad_val = grad_;
+ } else {
+ grad_val = grad;
+ }
auto& shape = sizes[i];
// If input was empty tensor, gradInput should be empty tensor.
if (shape == std::vector<int64_t>({0})) {
- grad_inputs[i] = at::zeros({0}, grad.options());
+ grad_inputs[i] = at::zeros({0}, grad_val.options());
continue;
}
auto size = shape[dim];
accumulate += size;
- grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
+ grad_inputs[i] = grad_val.narrow(dim, accumulate - size, size);
}
return grad_inputs;
}
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 4cce8cf..997c75e 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -76,7 +76,7 @@
at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
at::Tensor unsqueeze_to(const at::Tensor & self, at::IntArrayRef sizes);
at::Tensor unsqueeze_to(const at::Tensor & self, int64_t dim, at::IntArrayRef sizes);
-std::vector<at::Tensor> cat_tensors_backward(const at::Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim);
+std::vector<at::Tensor> cat_tensors_backward(const at::Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes, int64_t dim);
at::Tensor clamp_backward(const at::Tensor & grad, const at::Tensor &self, const optional<at::Scalar> & min, const optional<at::Scalar> & max);
at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name);
at::Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha);
diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h
index 85b83f2..2221533 100644
--- a/torch/csrc/autograd/VariableTypeUtils.h
+++ b/torch/csrc/autograd/VariableTypeUtils.h
@@ -309,4 +309,11 @@
return args_sizes;
}
+inline std::vector<ScalarType> to_args_scalartypes(TensorList tensors) {
+ std::vector<ScalarType> args_scalartypes(tensors.size());
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ args_scalartypes[i] = tensors[i].scalar_type();
+ }
+ return args_scalartypes;
+}
}} // namespace torch::autograd