torch.diag bug fix (#1251)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index a20f25d..3a1677b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1077,6 +1077,8 @@
     (Resize, (S * S, S), ((S, S, S),)),
     (Diag, (), ((S, S),), '2d'),
     (Diag, (), ((S,),), '1d'),
+    (Diag, (1,), ((S, S),), '2d_1'),
+    (Diag, (2,), ((S, S),), '2d_2'),
     (Tril, (), ((S, S),)),
     (Tril, (2,), ((S, S),), 'idx'),
     (Triu, (), ((S, S),)),
diff --git a/torch/autograd/_functions/linalg.py b/torch/autograd/_functions/linalg.py
index 4b522be..b0090f4 100644
--- a/torch/autograd/_functions/linalg.py
+++ b/torch/autograd/_functions/linalg.py
@@ -10,10 +10,10 @@
         self.diagonal_idx = diagonal_idx
 
     def forward(self, input):
-        return input.diag()
+        return input.diag(self.diagonal_idx)
 
     def backward(self, grad_output):
-        return grad_output.diag()
+        return grad_output.diag(self.diagonal_idx)
 
 
 class Tril(Function):