add fill_diagonal function (#21892)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/21796
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21892

Differential Revision: D16164678

Pulled By: colesbury

fbshipit-source-id: 85df8ae9b7a6a91b6023fe7295b3a8124e4526ea
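
For reference, a minimal usage sketch of the new in-place method (behaviour mirrors the docstring added in torch/_tensor_docs.py below):

    import torch

    x = torch.zeros(5, 5)
    x.fill_diagonal_(2)             # in-place: x[i, i] = 2 for every i
    y = torch.zeros(7, 3)
    y.fill_diagonal_(1, wrap=True)  # tall matrix: diagonal restarts at row 4
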
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index eafc5c1..95c02a9 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -406,6 +406,7 @@
   Tensor diag_embed(int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) const;
   Tensor diagflat(int64_t offset=0) const;
   Tensor diagonal(int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const;
+  Tensor & fill_diagonal_(Scalar fill_value, bool wrap=false);
   Tensor div(const Tensor & other) const;
   Tensor & div_(const Tensor & other);
   Tensor div(Scalar other) const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index d07c5f2..53cb9de 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -281,6 +281,10 @@
     static auto table = globalATenDispatch().getOpTable("aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)");
     return table->getOp<Tensor (const Tensor &, int64_t, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, offset, dim1, dim2);
 }
+inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) {
+    static auto table = globalATenDispatch().getOpTable("aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)");
+    return table->getOp<Tensor & (Tensor &, Scalar, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, fill_value, wrap);
+}
 inline Tensor Tensor::div(const Tensor & other) const {
     static auto table = globalATenDispatch().getOpTable("aten::div(Tensor self, Tensor other) -> Tensor");
     return table->getOp<Tensor (const Tensor &, const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, other);
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 97d632e..3dcb545 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -284,6 +284,7 @@
 _(aten, diag_embed) \
 _(aten, diagflat) \
 _(aten, diagonal) \
+_(aten, fill_diagonal_) \
 _(aten, digamma) \
 _(aten, dim) \
 _(aten, dist) \
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index a8edf32..638910d 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -287,6 +287,55 @@
   return native::full(self.sizes(), fill_value, options);
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Tensor& fill_diagonal_(Tensor& self, Scalar fill_value, bool wrap) {
+  int64_t nDims = self.dim();
+  TORCH_CHECK(nDims >= 2, "dimensions must be larger than 1");
+
+  int64_t height = self.size(0);
+  int64_t width = self.size(1);
+
+  if (nDims > 2) {
+    int64_t dim1 = height;
+    for (int64_t i = 1; i < nDims; i++) {
+      if (self.size(i) != dim1) {
+        AT_ERROR("all dimensions of input must be of equal length");
+      }
+    }
+  }
+
+  int64_t storage_offset = self.storage_offset();
+  std::vector<int64_t> sizes;
+  std::vector<int64_t> strides;
+  int64_t size = std::min(height, width);
+
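+  // The element (i, i, ..., i) sits at storage offset i * (sum of all strides),
+  // so a 1-D view with that single combined stride walks the main diagonal.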
+  int64_t stride = 0;
+  for (int64_t i = 0; i < nDims; i++) {
+    stride += self.stride(i);
+  }
+  strides.push_back(stride);
+  sizes.push_back(size);
+
+  auto main_diag = self.as_strided(sizes, strides, storage_offset);
+  main_diag.fill_(fill_value);
+
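+  // With wrap=true, a tall 2-D matrix continues the diagonal below the filled
+  // square block, restarting at column 0 (matching numpy.fill_diagonal).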
+  if (wrap && nDims == 2 && height > width + 1) {
+    std::vector<int64_t> wrap_sizes;
+
+    int64_t step = width + 1;
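+    // The diagonal advances one wrapped entry every (width + 1) elements;
+    // ceil(numel / step) counts all diagonal entries, minus the `size` already filled.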
+    int64_t wrap_size = ((self.numel() + step - 1) / step) - size;
+    wrap_sizes.push_back(wrap_size);
+
+    int64_t offset = self.stride(0) * (width + 1);
+
+    auto wrap_diag = self.as_strided(wrap_sizes, strides, storage_offset + offset);
+    wrap_diag.fill_(fill_value);
+  }
+
+  return self;
+}
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linspace ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Tensor linspace(
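
The kernel above materializes the diagonal as a 1-D strided view and fills it. A rough Python equivalent of that trick for the non-wrapping case (a sketch using torch.as_strided, not the code path the patch adds):

    import torch

    def fill_diagonal_sketch(t, value):
        # Moving from (i, i, ..., i) to (i+1, i+1, ..., i+1) advances the
        # storage offset by the sum of all strides.
        step = sum(t.stride())
        n = min(t.size(0), t.size(1))
        t.as_strided((n,), (step,), t.storage_offset()).fill_(value)
        return t
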
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 07ca340..0cb4307 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -649,6 +649,9 @@
 - func: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
   variants: function, method
 
+- func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
+  variants: method
+
 - func: div(Tensor self, Tensor other) -> Tensor
   variants: function, method
 
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 5b2fcc3..0f60f8a 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -226,6 +226,7 @@
    .. automethod:: diag_embed
    .. automethod:: diagflat
    .. automethod:: diagonal
+   .. automethod:: fill_diagonal_
    .. automethod:: digamma
    .. automethod:: digamma_
    .. automethod:: dim
diff --git a/test/test_torch.py b/test/test_torch.py
index dc48fd6..9fe39f7 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -12239,6 +12239,46 @@
         c2 = torch.tensor([True, False], dtype=bool)
         self.assertEqual(c1.dtype, c2.dtype)
 
+    def test_fill_diagonal(self):
+        a1 = torch.randn(7, 3)
+        a2 = a1.clone()
+        v = 1
+        for i in range(3):
+            a2[i][i] = v
+        a1.fill_diagonal_(v)
+        self.assertEqual(a1, a2)
+
+        b1 = torch.randn(7, 3)
+        b2 = b1.clone()
+        for i in range(3):
+            b2[i][i] = v
+            b2[i + 4][i] = v
+        b1.fill_diagonal_(v, wrap=True)
+        self.assertEqual(b1, b2)
+
+        c1 = torch.rand(3, 3, 3)
+        c2 = c1.clone()
+        for i in range(3):
+            c2[i][i][i] = v
+        c1.fill_diagonal_(v)
+        self.assertEqual(c1, c2)
+
+        # non-contiguous tensor
+        d1 = torch.rand(3, 3, 3)[:, 1, ...]
+        d2 = d1.clone()
+        for i in range(3):
+            d2[i][i] = v
+        d1.fill_diagonal_(v)
+        self.assertEqual(d1, d2)
+
+        e1 = torch.rand(7, 3, 3)[:, 1, ...]
+        e2 = e1.clone()
+        for i in range(3):
+            e2[i][i] = v
+            e2[i + 4][i] = v
+        e1.fill_diagonal_(v, wrap=True)
+        self.assertEqual(e1, e2)
+
 # Functions to test negative dimension wrapping
 METHOD = 1
 INPLACE_METHOD = 2
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 9018658..d6eeeb5 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -730,6 +730,46 @@
 See :func:`torch.diagonal`
 """)
 
+add_docstr_all('fill_diagonal_',
+               r"""
+fill_diagonal_(fill_value, wrap=False) -> Tensor
+
+Fills the main diagonal of a tensor that has at least 2 dimensions.
+When dims > 2, all dimensions of the input must be of equal length.
+This function modifies the input tensor in-place and returns it.
+
+Arguments:
+    fill_value (Scalar): the fill value
+    wrap (bool): whether the diagonal is 'wrapped' after N columns for tall matrices.
+
+Example::
+
+    >>> a = torch.zeros(3, 3)
+    >>> a.fill_diagonal_(5)
+    tensor([[5., 0., 0.],
+            [0., 5., 0.],
+            [0., 0., 5.]])
+    >>> b = torch.zeros(7, 3)
+    >>> b.fill_diagonal_(5)
+    tensor([[5., 0., 0.],
+            [0., 5., 0.],
+            [0., 0., 5.],
+            [0., 0., 0.],
+            [0., 0., 0.],
+            [0., 0., 0.],
+            [0., 0., 0.]])
+    >>> c = torch.zeros(7, 3)
+    >>> c.fill_diagonal_(5, wrap=True)
+    tensor([[5., 0., 0.],
+            [0., 5., 0.],
+            [0., 0., 5.],
+            [0., 0., 0.],
+            [5., 0., 0.],
+            [0., 5., 0.],
+            [0., 0., 5.]])
+
+""")
+
 add_docstr_all('digamma',
                r"""
 digamma() -> Tensor