Fix `cdist` backward calculation for `p=2` (#37337)
Summary:
Closes https://github.com/pytorch/pytorch/issues/37154
Fixes a bug in `cdist` backward with `p=2`.
Under some circumstances, when the output contains 0s, the gradient of `sqrt` is undefined at those entries, which leads to NaNs in the input gradients.
This PR defines a subgradient for this case.
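Concretely, the `p=2` backward goes through the ratio `grad / dist`, which blows up where `dist == 0`; the new `_euclidean_dist_backward` masks those entries to 0, which is a valid subgradient there. A rough Python transcription of that formula (for readability only; the actual change is the C++ version in `Functions.cpp` below):

```python
import torch

def euclidean_dist_backward_sketch(grad, x1, x2, res):
    # ratio_ij = grad_ij / dist_ij; entries with dist_ij == 0 get the
    # subgradient 0 instead of inf/nan
    ratio = grad / res
    ratio = ratio.masked_fill(res == 0, 0)
    grad_x1 = x1 * ratio.sum(-1, keepdim=True) - ratio.matmul(x2)
    grad_x2 = x2 * ratio.sum(-2).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)
    return grad_x1, grad_x2
```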
A test is also added to verify this behavior. I was only able to reproduce the issue with certain shapes, so the test shape is taken directly from the example in https://github.com/pytorch/pytorch/issues/37154.
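For reference, a minimal sketch of the failure the new test guards against (shape borrowed from the issue; the exact values are random, the point is that identical rows give 0 distances):

```python
import torch

x = torch.randn(1, 27, 32, requires_grad=True)
y = x.detach().clone()               # identical rows -> zero distances
d = torch.cdist(x, y, p=2)
d.backward(torch.ones_like(d))
print(torch.isfinite(x.grad).all())  # was False (NaNs) before this change
```

The 27-row shape matters because the matmul-based `p=2` path is only taken when one of the inputs has more than 25 rows (the `r1 > 25 || r2 > 25` condition in `Distance.cpp`), which is presumably why smaller shapes did not reproduce the issue.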
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37337
Differential Revision: D21403178
Pulled By: albanD
fbshipit-source-id: deef9678c1958524b552504920f19617f9ad1da6
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index dd6b329..c5b7ad2 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -530,6 +530,7 @@
_(aten, orgqr) \
_(aten, ormqr) \
_(aten, pairwise_distance) \
+_(aten, _euclidean_dist) \
_(aten, pdist) \
_(aten, cdist) \
_(aten, permute) \
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index 4d32b90..b2b7605 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -25,7 +25,10 @@
return at::_pdist_forward(self.contiguous(), p);
}
-Tensor euclidean_dist_out(const Tensor& x1, const Tensor& x2) {
+Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) {
+ /** This function does the first part of the Euclidean distance calculation.
+ * We split it into two steps to simplify dealing with subgradients in the
+ * backward step. */
Tensor x1_norm = x1.pow(2).sum(-1, true);
Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor x2_norm = x2.pow(2).sum(-1, true);
@@ -87,8 +90,8 @@
} else if (c1 == 0) {
result = at::zeros(output_shape, x1.options());
} else if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
- Tensor dist = (expand_batch_product == 1) ? euclidean_dist_out(x1, x2) :
- euclidean_dist_out(tensor1_expanded, tensor2_expanded);
+ Tensor dist = (expand_batch_product == 1) ? at::_euclidean_dist(x1, x2) :
+ at::_euclidean_dist(tensor1_expanded, tensor2_expanded);
result = dist.view(output_shape);
} else {
result = at::empty(output_shape, x1.options());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 0fc4bbd..1477835 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2249,6 +2249,9 @@
use_c10_dispatcher: full
supports_named_tensor: True
+- func: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor
+ use_c10_dispatcher: full
+
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
use_c10_dispatcher: full
supports_named_tensor: True
diff --git a/test/test_autograd.py b/test/test_autograd.py
index a2009df..cf54de8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5300,7 +5300,6 @@
x = x - (((x - y) < eps).float() * 2 * eps)
x.requires_grad = True
y.requires_grad = True
- f_args_variable = (x, y)
dist = torch.cdist(x, y, p=2)
# Do a backward pass to check that it is valid for large
# matrices
@@ -5315,6 +5314,21 @@
_test_cdist_for_size((1, 1), (S, 1))
_test_euclidean_large_cdist((2000, 5))
+ def test_cdist_same_inputs(self, device):
+ # Test to detect issues in cdist gradient calculation
+ # when the distances are 0
+ sizex = (1, 27, 32)
+ for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+ x = torch.randn(sizex, device=device, dtype=torch.float)
+ dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
+ y = x.clone()
+ eps = 1e-6
+ x.requires_grad = True
+ d = torch.cdist(x, y, p=p)
+ d.backward(dist_grad)
+ # Check that the backward pass does not produce invalid
+ # values such as nan or inf
+ assert torch.isfinite(x.grad).all()
def test_parameter_resize(self, device):
asd = torch.nn.Parameter(torch.ones(16, device=device))
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 1a25d26..ad4645b 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -669,6 +669,9 @@
self: not_implemented("_pdist_backward")
pdist: not_implemented("_pdist_backward")
+- name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor
+ x1, x2: _euclidean_dist_backward(grad, x1, x2, result)
+
- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
x1: _cdist_backward(grad.contiguous(), x1, x2, p, result)
x2: _cdist_backward(grad.transpose(-1, -2).contiguous(), x2, x1, p, result.transpose(-1, -2).contiguous())
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 7e17832..7e128e1 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -90,6 +90,15 @@
return size;
}
+std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) {
+ // handle the case at 0, where we return a subgradient containing 0
+ Tensor ratio = grad / res;
+ ratio.masked_fill_(res == 0, 0);
+ return std::tuple<Tensor, Tensor>{
+ x1 * ratio.sum(-1, true) - ratio.matmul(x2),
+ x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)};
+}
+
Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
double p = p_.value_or(2.0).toDouble();
Tensor self_scaled;