Skip manual backward for `cdist` with case `p=2` (#31167)
Summary:
Fixes an issue with `cdist` backward calculation for large inputs for the euclidean case.
The grid size when launching the kernel exceeded the 2^16 limit for the second dimension, resulting in `RuntimeError: CUDA error: invalid configuration argument`
Code to reproduce:
```
h, w, d = 800, 1216, 12
n = 133
A = torch.randn(n, d).cuda()
B = torch.randn(h, w, d).cuda()
A.requires_grad = True
B.requires_grad = True
B = B.reshape(-1, d).contiguous()
dist = torch.cdist(A, B)
loss = dist.sum()
loss.backward()
```
Thanks to tkerola for the bug report, reproduction and suggesting a solution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31167
Differential Revision: D20035605
Pulled By: ngimel
fbshipit-source-id: ae28ba4b549ee07a8bd937bb1de2438dc24eaa17
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index d19073e..4d32b90 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -104,6 +104,28 @@
auto maybe_outnames = namedinference::compute_cdist_outnames(x1, x2);
auto result = [&]() {
NoNamesGuard guard;
+ // This is for pytorch to figure the backward pass itself
+ // when p=2
+ int64_t r1 = x1.size(-2);
+ int64_t r2 = x2.size(-2);
+ int64_t mode = compute_mode.value_or(0);
+ if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
+ return cdist_impl(x1, x2, p, compute_mode);
+ } else {
+ return at::_cdist_forward(x1, x2, p, compute_mode);
+ }
+ }();
+ namedinference::propagate_names_if_nonempty(result, maybe_outnames);
+ return result;
+}
+
+Tensor _cdist_forward(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
+ TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
+ TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
+ TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
+ auto maybe_outnames = namedinference::compute_cdist_outnames(x1, x2);
+ auto result = [&]() {
+ NoNamesGuard guard;
return cdist_impl(x1, x2, p, compute_mode);
}();
namedinference::propagate_names_if_nonempty(result, maybe_outnames);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a095ea2..5b2a250 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2135,6 +2135,9 @@
use_c10_dispatcher: full
- func: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor
+ supports_named_tensor: True
+
+- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
use_c10_dispatcher: full
supports_named_tensor: True
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 333ea4f..2a7b528 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4136,12 +4136,31 @@
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
run_functional_checks(self, "test_cdist", "cdist", f,
True, f_args_variable, f_args_tensor)
+
+ def _test_euclidean_large_cdist(sizex, sizey=None):
+ if sizey is None:
+ sizey = sizex
+ x = torch.randn(sizex, device=device, dtype=torch.float)
+ y = torch.randn(sizey, device=device, dtype=torch.float)
+ eps = 1e-6
+ # to avoid extremum
+ x = x - (((x - y) < eps).float() * 2 * eps)
+ x.requires_grad = True
+ y.requires_grad = True
+ f_args_variable = (x, y)
+ dist = torch.cdist(x, y, p=2)
+ # Do a backward pass to check that it is valid for large
+ # matrices
+ loss = dist.sum()
+ loss.backward()
+
_test_cdist_for_size((S, S))
_test_cdist_for_size((S, S, S))
_test_cdist_for_size((3, 5))
_test_cdist_for_size((2, 3, 5))
_test_cdist_for_size((1, 2, 3))
_test_cdist_for_size((1, 1), (S, 1))
+ _test_euclidean_large_cdist((2000, 5))
# NOTE: flaky on ROCm CI
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index cf30af4..c9d1754 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -653,7 +653,7 @@
self: not_implemented("_pdist_backward")
pdist: not_implemented("_pdist_backward")
-- name: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor
+- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
x1: _cdist_backward(grad.contiguous(), x1, x2, p, result)
x2: _cdist_backward(grad.transpose(-1, -2).contiguous(), x2, x1, p, result.transpose(-1, -2).contiguous())