[complex] add autograd support for torch.polar (#52488)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/33152
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52488
Reviewed By: zou3519
Differential Revision: D26711841
Pulled By: anjali411
fbshipit-source-id: b8538fb8cb44456b832e4f993cf41954b3ddd2e8
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 166641b..a03ddf8 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -325,8 +325,7 @@
imag: at::imag(grad)
- name: polar(Tensor abs, Tensor angle) -> Tensor
- abs: not_implemented("polar abs")
- angle: not_implemented("polar angle")
+ abs, angle: polar_backward(grad, result)
- name: _conj(Tensor self) -> Tensor
self: grad.conj()
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 5a92bc2..5d53242 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -91,7 +91,7 @@
'reflection_pad1d_backward', 'reflection_pad2d_backward',
'replication_pad1d', 'replication_pad2d', 'replication_pad3d',
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
- 'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace'
+ 'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar'
}
# Some operators invalidate the grad_accumulator. Let's reset it.
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 80461fa..719e3cf 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -2987,6 +2987,19 @@
return false;
}
+std::tuple<Tensor, Tensor> polar_backward(
+ const Tensor& grad,
+ const Tensor& result) {
+ Tensor grad_abs, grad_angle;
+ if (grad.defined()) {
+ auto grad_conj = grad.conj();
+ grad_abs = at::real(grad_conj * at::sgn(result));
+ auto result_mul_1_j = result * Scalar(c10::complex<double>{0.0, 1.0});
+ grad_angle = at::real(grad_conj * result_mul_1_j);
+ }
+ return std::make_tuple(grad_abs, grad_angle);
+}
+
} // namespace details
} // namespace generated
} // namespace autograd
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 546736a..ea64a01 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -221,7 +221,9 @@
IntArrayRef normalized_shape,
double eps,
std::array<bool, 3> grad_input_mask);
-
+std::tuple<Tensor, Tensor> polar_backward(
+ const Tensor& grad,
+ const Tensor& result);
} // namespace details
} // namespace generated
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 16ad02d..56a723f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1350,6 +1350,18 @@
return samples
+
+def sample_inputs_polar(op_info, device, dtype, requires_grad):
+ def _make_tensor_helper(shape, low=None, high=None):
+ return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
+
+ samples = (
+ SampleInput((_make_tensor_helper((S, S), low=0), _make_tensor_helper((S, S)))),
+ SampleInput((_make_tensor_helper((), low=0), _make_tensor_helper(()))),
+ )
+
+ return samples
+
# Operator database (sorted alphabetically)
op_db: List[OpInfo] = [
UnaryUfuncInfo('abs',
@@ -2396,6 +2408,10 @@
# cuda gradchecks are very slow
# see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)),
+ OpInfo('polar',
+ dtypes=floating_types(),
+ test_inplace_grad=False,
+ sample_inputs_func=sample_inputs_polar),
OpInfo('pinverse',
op=torch.pinverse,
dtypes=floating_and_complex_types(),