torch.diag bug fix (#1251)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index a20f25d..3a1677b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1077,6 +1077,8 @@
(Resize, (S * S, S), ((S, S, S),)),
(Diag, (), ((S, S),), '2d'),
(Diag, (), ((S,),), '1d'),
+ (Diag, (1,), ((S, S),), '2d_1'),
+ (Diag, (2,), ((S, S),), '2d_2'),
(Tril, (), ((S, S),)),
(Tril, (2,), ((S, S),), 'idx'),
(Triu, (), ((S, S),)),
diff --git a/torch/autograd/_functions/linalg.py b/torch/autograd/_functions/linalg.py
index 4b522be..b0090f4 100644
--- a/torch/autograd/_functions/linalg.py
+++ b/torch/autograd/_functions/linalg.py
@@ -10,10 +10,10 @@
self.diagonal_idx = diagonal_idx
def forward(self, input):
- return input.diag()
+ return input.diag(self.diagonal_idx)
def backward(self, grad_output):
- return grad_output.diag()
+ return grad_output.diag(self.diagonal_idx)
class Tril(Function):