Fix `cdist` backward calculation for `p=2` (#37337)

Summary:
Closes https://github.com/pytorch/pytorch/issues/37154

Fixes a bug in `cdist` backward with `p=2`.
When the output contains 0s (e.g. when computing the distances between a tensor and a copy of itself), the gradient of the final `sqrt` is undefined, which leads to NaNs in the input gradients.
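
For illustration, a minimal sketch of the mechanism (it rebuilds the distance from elementwise ops instead of calling `cdist`, but hits the same `inf * 0` pattern; the shape is taken from the linked issue):

```python
import torch

# d = sqrt(s) with s = ||x_i - y_j||^2.  The derivative of sqrt is 1 / (2 * sqrt(s)),
# which is infinite at s = 0, while ds/dx_i carries a factor (x_i - y_j) that is 0
# there, so the chain rule produces inf * 0 = NaN whenever a row of x equals a row of y.
x = torch.randn(1, 27, 32, requires_grad=True)  # shape taken from issue #37154
y = x.detach().clone()                          # identical rows -> zero distances
s = ((x.unsqueeze(-2) - y.unsqueeze(-3)) ** 2).sum(-1)
s.sqrt().sum().backward()
print(torch.isfinite(x.grad).all())             # tensor(False): NaNs in x.grad
```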

This PR defines a subgradient for this case.
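
Concretely, the new backward uses d(d_ij)/d(x1_i) = (x1_i - x2_j) / d_ij and substitutes 0 wherever d_ij = 0. Below is a rough Python sketch of the idea (illustrative helper names, not a line-for-line translation of the ATen kernel):

```python
import torch

def euclidean_dist(x1, x2):
    # d_ij^2 = ||x1_i||^2 + ||x2_j||^2 - 2 * <x1_i, x2_j>
    sq = (x1.pow(2).sum(-1, keepdim=True)
          + x2.pow(2).sum(-1, keepdim=True).transpose(-2, -1)
          - 2 * x1.matmul(x2.transpose(-2, -1)))
    return sq.clamp_min(0).sqrt()

def euclidean_dist_backward(grad, x1, x2, d):
    # d(d_ij)/d(x1_i) = (x1_i - x2_j) / d_ij; use 0 as the subgradient where d_ij == 0
    ratio = (grad / d).masked_fill(d == 0, 0)
    grad_x1 = x1 * ratio.sum(-1, keepdim=True) - ratio.matmul(x2)
    grad_x2 = x2 * ratio.sum(-2).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)
    return grad_x1, grad_x2

x = torch.randn(1, 27, 32)
g = torch.randn(1, 27, 27)
gx1, gx2 = euclidean_dist_backward(g, x, x.clone(), euclidean_dist(x, x.clone()))
print(torch.isfinite(gx1).all(), torch.isfinite(gx2).all())  # tensor(True) tensor(True)
```

Away from zeros this should match what autograd computes through the composite expression; at zeros it simply stays finite.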

A test is also added to verify this behavior. I was only able to reproduce the problem for certain shapes, so the shape is taken directly from the example in https://github.com/pytorch/pytorch/issues/37154.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37337

Differential Revision: D21403178

Pulled By: albanD

fbshipit-source-id: deef9678c1958524b552504920f19617f9ad1da6
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index dd6b329..c5b7ad2 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -530,6 +530,7 @@
 _(aten, orgqr) \
 _(aten, ormqr) \
 _(aten, pairwise_distance) \
+_(aten, _euclidean_dist) \
 _(aten, pdist) \
 _(aten, cdist) \
 _(aten, permute) \
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index 4d32b90..b2b7605 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -25,7 +25,10 @@
   return at::_pdist_forward(self.contiguous(), p);
 }
 
-Tensor euclidean_dist_out(const Tensor& x1, const Tensor& x2) {
+Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) {
+  /** This function does the first part of the Euclidean distance calculation.
+   * We divide it in two steps to simplify dealing with subgradients in the
+   * backward step. */
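+  // d(x1_i, x2_j)^2 expands to |x1_i|^2 + |x2_j|^2 - 2 * <x1_i, x2_j>; the squared
+  // norms computed below are the first two terms of that expansion.  Registering the
+  // op with its own derivative (see derivatives.yaml) is what allows a finite
+  // subgradient to be chosen where the distance is 0.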
   Tensor x1_norm = x1.pow(2).sum(-1, true);
   Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   Tensor x2_norm = x2.pow(2).sum(-1, true);
@@ -87,8 +90,8 @@
   } else if (c1 == 0) {
     result = at::zeros(output_shape, x1.options());
   } else if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
-    Tensor dist = (expand_batch_product == 1) ? euclidean_dist_out(x1, x2) :
-                  euclidean_dist_out(tensor1_expanded, tensor2_expanded);
+    Tensor dist = (expand_batch_product == 1) ? at::_euclidean_dist(x1, x2) :
+                  at::_euclidean_dist(tensor1_expanded, tensor2_expanded);
     result = dist.view(output_shape);
   } else {
     result = at::empty(output_shape, x1.options());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 0fc4bbd..1477835 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2249,6 +2249,9 @@
   use_c10_dispatcher: full
   supports_named_tensor: True
 
+- func: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor
+  use_c10_dispatcher: full
+
 - func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
   use_c10_dispatcher: full
   supports_named_tensor: True
diff --git a/test/test_autograd.py b/test/test_autograd.py
index a2009df..cf54de8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5300,7 +5300,6 @@
             x = x - (((x - y) < eps).float() * 2 * eps)
             x.requires_grad = True
             y.requires_grad = True
-            f_args_variable = (x, y)
             dist = torch.cdist(x, y, p=2)
             # Do a backward pass to check that it is valid for large
             # matrices
@@ -5315,6 +5314,21 @@
         _test_cdist_for_size((1, 1), (S, 1))
         _test_euclidean_large_cdist((2000, 5))
 
+    def test_cdist_same_inputs(self, device):
+        # Test to detect issues in the cdist gradient calculation
+        # when the distances are 0
+        sizex = (1, 27, 32)
+        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
+            y = x.clone()
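+            # y is an exact copy of x, so all "diagonal" distances are exactly 0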
+            x.requires_grad = True
+            d = torch.cdist(x, y, p=p)
+            d.backward(dist_grad)
+            # Check that the backward pass does not produce invalid
+            # values such as NaN or Inf
+            assert torch.isfinite(x.grad).all()
 
     def test_parameter_resize(self, device):
         asd = torch.nn.Parameter(torch.ones(16, device=device))
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 1a25d26..ad4645b 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -669,6 +669,9 @@
   self: not_implemented("_pdist_backward")
   pdist: not_implemented("_pdist_backward")
 
+- name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor
+  x1, x2: _euclidean_dist_backward(grad, x1, x2, result)
+
 - name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
   x1: _cdist_backward(grad.contiguous(), x1, x2, p, result)
   x2: _cdist_backward(grad.transpose(-1, -2).contiguous(), x2, x1, p, result.transpose(-1, -2).contiguous())
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 7e17832..7e128e1 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -90,6 +90,15 @@
   return size;
 }
 
+std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) {
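+  // For d_ij = |x1_i - x2_j|, d(d_ij)/d(x1_i) = (x1_i - x2_j) / d_ij.  Writing
+  // ratio_ij = grad_ij / d_ij, the chain rule gives
+  //   grad_x1_i = x1_i * sum_j(ratio_ij) - sum_j(ratio_ij * x2_j)
+  // and symmetrically for x2, which is what the expressions below compute.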
+  // handle case at 0 where we return a subgradient containing 0
+  Tensor ratio = grad / res;
+  ratio.masked_fill_(res == 0, 0);
+  return std::tuple<Tensor, Tensor>{
+            x1 * ratio.sum(-1, true) - ratio.matmul(x2),
+            x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)};
+}
+
 Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
   double p = p_.value_or(2.0).toDouble();
   Tensor self_scaled;