Make atan2 backward reuse intermediate computation.
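
Both gradients of atan2 share the factor 1 / (self^2 + other^2):

    d atan2(self, other) / d self  =  other / (self^2 + other^2)
    d atan2(self, other) / d other = -self / (self^2 + other^2)

Computing them through a single atan2_backward helper evaluates the
reciprocal once instead of twice, and output_mask lets the helper skip
whichever gradient is not needed.
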
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 62bcf44..d9d4edc 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -61,8 +61,7 @@
   self: grad * (self * self + 1).reciprocal()
 
 - name: atan2(Tensor self, Tensor other)
-  self: grad * other * ((self * self + other * other).reciprocal())
-  other: grad * -self * ((self * self + other * other).reciprocal())
+  self, other: atan2_backward(grad, self, other, output_mask)
 
 - name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1)
   self: maybe_multiply(grad, beta)
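
As an aside (not part of the patch): the combined "self, other:" entry means the
generated backward makes one call that returns both gradients, with output_mask
recording which ones are actually needed. A rough hand-written sketch of such a
call, assuming an ATen build and that atan2_backward is declared somewhere
visible to the caller (the function and mask values below are illustrative, not
the actual generated code):

    #include <array>
    #include <tuple>
    #include <ATen/ATen.h>

    // Illustrative only: mimics how a combined formula is consumed.
    // A masked-out slot comes back as an undefined Tensor, and its
    // product with grad is never formed.
    void example_backward_call(const at::Tensor& grad,
                               const at::Tensor& self,
                               const at::Tensor& other) {
      std::array<bool, 2> output_mask = {true, false};  // e.g. only `self` requires grad
      at::Tensor grad_self, grad_other;
      std::tie(grad_self, grad_other) =
          atan2_backward(grad, self, other, output_mask);
      // grad_other.defined() == false here because output_mask[1] was false.
    }
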
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index cf50f2d..1b80510 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -413,6 +413,13 @@
   return src;
 }
 
+std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
+  auto recip = (self * self + other * other).reciprocal();
+  return std::tuple<Tensor, Tensor>{
+    output_mask[0] ? grad * other * recip : Tensor(),
+    output_mask[1] ? grad * -self * recip : Tensor() };
+}
+
 }
 
 ${autograd_function_definitions}
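
A quick standalone sanity check (separate from the patch) that the closed forms
the helper implements match finite differences of std::atan2:

    #include <cmath>
    #include <cstdio>

    int main() {
      double self = 0.7, other = -1.3, eps = 1e-6;

      // Closed forms from derivatives.yaml; both reuse 1 / (self^2 + other^2).
      double recip   = 1.0 / (self * self + other * other);
      double d_self  =  other * recip;
      double d_other = -self * recip;

      // Central differences of atan2 with respect to each argument.
      double n_self  = (std::atan2(self + eps, other) - std::atan2(self - eps, other)) / (2 * eps);
      double n_other = (std::atan2(self, other + eps) - std::atan2(self, other - eps)) / (2 * eps);

      std::printf("d/dself:  analytic=%.8f  numeric=%.8f\n", d_self, n_self);
      std::printf("d/dother: analytic=%.8f  numeric=%.8f\n", d_other, n_other);
      return 0;
    }
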