Fix select backward when wrap dim (#9033)
Summary:
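The previous backward for `select` was broken when `index=-1`: it was implemented via `slice_backward` over the range `[index, index + 1)`, which for `index=-1` is `[-1:0]` and selects an empty tensor/list/array. It now uses a dedicated `select_backward` that calls `select(dim, index)` directly.
Added a test (`select` with `index=-1`, variant name `wrap_dim`).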
cc goldsborough
Closes https://github.com/pytorch/pytorch/pull/9033
Differential Revision: D8694300
Pulled By: SsnL
fbshipit-source-id: 8377b043896f8d0b1da173cc0077ace0bea5e862
diff --git a/test/test_autograd.py b/test/test_autograd.py
index b74d7d0..6e8a746 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -3045,6 +3045,7 @@
('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim'),
('permute', (), (dont_convert(()),), 'scalar'),
('select', (S, S, S), (1, 2), 'dim', [0]),
+ ('select', (S, S, S), (1, -1), 'wrap_dim', [0]),
('select', (S,), (0, 2), '1d'),
('narrow', (S, S, S), (1, 2, 2), 'dim', [0]),
('narrow', (S, S, S), (1, 0, 0), 'empty_dim', [0], [skipIfNoZeroSize]),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 9b1fb04..fd2ecc6 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -603,7 +603,7 @@
src: grad.gather(dim, index)
- name: select(Tensor self, int64_t dim, int64_t index)
- self: slice_backward(grad.unsqueeze(dim), self.sizes(), dim, index, index + 1, 1)
+ self: select_backward(grad, self.sizes(), dim, index)
- name: sigmoid(Tensor self)
self: _sigmoid_backward(grad, result)
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 12b29fa..016e533 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -521,11 +521,17 @@
}
Tensor slice_backward(Tensor grad, IntList input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
- auto grad_input = at::zeros(input_sizes, grad.type());
+ auto grad_input = at::zeros(input_sizes, grad.options());
grad_input.slice(dim, start, end, step).copy_(grad);
return grad_input;
}
+Tensor select_backward(Tensor grad, IntList input_sizes, int64_t dim, int64_t index) {
+ auto grad_input = at::zeros(input_sizes, grad.options());
+ grad_input.select(dim, index).copy_(grad);
+ return grad_input;
+}
+
Tensor trace_backward(const Tensor & grad, IntList sizes) {
if (sizes.size() != 2) {
throw std::runtime_error("expected matrix input");