Add vmap support for torch.index_fill (#91364)

Fixes #91177
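
This adds batch rules for index_fill.int_Scalar and index_fill.int_Tensor
plus their in-place variants. When the fill value is a scalar and self has
logical rank >= 1, the batch dim is folded into `dim` and each sample's
indices are offset so a single index_fill_ call covers the whole batch; the
in-place and tensor-value cases fall back to a per-sample loop.

A minimal sketch of what this enables, e.g. via functorch's vmap (shapes and
values here are illustrative, not taken from the PR):

    import torch
    from functorch import vmap

    x = torch.zeros(5)
    index = torch.tensor([[0, 2], [1, 3]])  # one index tensor per batch entry
    # Fill x at index[b] with 1.0, independently for each batch entry b.
    out = vmap(torch.index_fill, in_dims=(None, None, 0, None))(x, 0, index, 1.0)
    # out.shape == (2, 5)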

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91364
Approved by: https://github.com/zou3519
diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
index da1711e..a346e5f 100644
--- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -1056,6 +1056,164 @@
   return std::make_tuple(result, 0);
 }
 
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
+    Tensor & self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor & index, optional<int64_t> index_bdim,
+    const Scalar & value,
+    const bool inplace) {
+  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
+  Tensor self_ = moveBatchDimToFront(self, self_bdim);
+  Tensor index_ = moveBatchDimToFront(index, index_bdim);
+  dim = maybe_wrap_dim(dim, self_logical_rank);
+
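+  // An in-place op cannot write per-sample results into an unbatched
+  // `self`, so raise the standard vmap incompatible-inplace error.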
+  if (inplace && !self_bdim.has_value()) {
+    vmapIncompatibleInplaceError("index_fill_");
+  }
+
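+  // Fast path: only `self` is batched, so one index_fill_ along the dim
+  // shifted past the batch dim handles every sample at once.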
+  if (!index_bdim) {
+    if (self_logical_rank == 0) {
+      self_.unsqueeze_(-1);
+    }
+    self_.index_fill_(dim + 1, index_, value);
+    if (self_logical_rank == 0) {
+      self_.squeeze_(-1);
+    }
+    return std::make_tuple(self_, 0);
+  }
+
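+  // Align both tensors to a common batch size; an unbatched input is
+  // broadcast along a new leading dim.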
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
+
+  if (inplace) {
+    // Use a for-loop for the in-place case: the reshape-based path below
+    // would have to copy `self_` when its strides are incompatible, which
+    // would break in-place semantics.
+    for (const auto i : c10::irange(0, batch_size)) {
+      const auto& self_slice = self_.select(0, i);
+      const auto& index_slice = index_.select(0, i);
+      self_slice.index_fill_(
+        dim,
+        index_slice,
+        value
+      );
+    }
+    return std::make_tuple(self_, 0);
+  }
+
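+  // If `self` had no batch dim, `self_` is an expanded view from
+  // ensure_has_bdim; clone it so the in-place fill below doesn't write
+  // through the broadcast.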
+  self_ = self_bdim.has_value() ? self_ : self_.clone();
+
+  if (self_logical_rank != 0) {
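+    // Fold the batch into `dim`: shift each sample's indices by
+    // b * self_.size(dim + 1) so a single index_fill_ over the flattened
+    // dim touches only that sample's slice. E.g. with B=2, self_ of shape
+    // (2, 3), dim=0 and per-sample indices [0] and [2], the shifted
+    // indices are [0, 5] into the flattened (6,) view.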
+    auto index_offset = at::arange(
+      batch_size,
+      at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
+    );
+    if (index_logical_rank == 0) {
+      index_ = index_.unsqueeze(-1);
+    }
+    index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
+    index_ = reshape_dim_into(0, 0, index_);
+    self_ = reshape_dim_into(0, dim, self_);
+    self_.index_fill_(dim, index_, value);
+    self_ = reshape_dim_outof(dim, batch_size, self_);
+    return std::make_tuple(self_, dim);
+  }
+
+  // If self_logical_rank == 0, self_ consists of only the batch dim and the
+  // sole valid index is 0, so flatten the batched indices and fill along an
+  // unsqueezed trailing dim, one scalar per batch entry.
+  if (index_logical_rank != 0) {
+    index_ = reshape_dim_into(0, 0, index_);
+  }
+  self_.unsqueeze_(-1);
+  self_.index_fill_(dim + 1, index_, value);
+  self_.squeeze_(-1);
+
+  return std::make_tuple(self_, 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
+    Tensor & self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor & index, optional<int64_t> index_bdim,
+    const Tensor & value, optional<int64_t> value_bdim,
+    const bool inplace) {
+  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+  Tensor self_ = moveBatchDimToFront(self, self_bdim);
+  Tensor index_ = moveBatchDimToFront(index, index_bdim);
+  Tensor value_ = moveBatchDimToFront(value, value_bdim);
+  dim = maybe_wrap_dim(dim, self_logical_rank);
+
+  if (inplace && !self_bdim.has_value()) {
+    vmapIncompatibleInplaceError("index_fill_");
+  }
+
+  if (!index_bdim && !value_bdim) {
+    if (self_logical_rank == 0) {
+      self_.unsqueeze_(-1);
+    }
+    self_.index_fill_(dim + 1, index_, value);
+    if (self_logical_rank == 0) {
+      self_.squeeze_(-1);
+    }
+    return std::make_tuple(self_, 0);
+  }
+
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
+  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
+
+  self_ = self_bdim.has_value() ? self_ : self_.clone();
+
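+  // Fall back to a per-sample loop: the index-offset trick used for the
+  // scalar case does not apply here, since the fill value can differ per
+  // sample.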
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_.select(0, i);
+    const auto& index_slice = index_.select(0, i);
+    const auto& value_slice = value_.select(0, i);
+    self_slice.index_fill_(
+      dim,
+      index_slice,
+      value_slice
+    );
+  }
+
+  return std::make_tuple(self_, 0);
+}
+
+void index_fill__int_scalar_batch_rule(
+    Tensor & self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor & index, optional<int64_t> index_bdim,
+    const Scalar & value) {
+  index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
+}
+
+void index_fill__int_tensor_batch_rule(
+    Tensor & self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor & index, optional<int64_t> index_bdim,
+    const Tensor & value, optional<int64_t> value_bdim) {
+  index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+    const Tensor & self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor & index, optional<int64_t> index_bdim,
+    const Scalar & value) {
+  auto self_ = self.clone(at::MemoryFormat::Preserve);
+  return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+    const Tensor & self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor & index, optional<int64_t> index_bdim,
+    const Tensor & value, optional<int64_t> value_bdim) {
+  auto self_ = self.clone(at::MemoryFormat::Preserve);
+  return index_fill_int_tensor_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, value_bdim, false);
+}
+
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
@@ -1066,6 +1224,10 @@
   m.impl("index_copy", index_copy_decomp);
   m.impl("index_select", index_select_decomp);
   VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule);
+  VMAP_SUPPORT2(index_fill_, int_Tensor, index_fill__int_tensor_batch_rule);
+  VMAP_SUPPORT2(index_fill_, int_Scalar, index_fill__int_scalar_batch_rule);
+  VMAP_SUPPORT2(index_fill, int_Tensor, index_fill_int_tensor_batch_rule);
+  VMAP_SUPPORT2(index_fill, int_Scalar, index_fill_int_scalar_batch_rule);
   VMAP_SUPPORT(index_add, index_add_batch_rule);
   VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
   VMAP_SUPPORT(gather, gather_batch_rule);
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 633c1ab..5265490 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1043,7 +1043,6 @@
         xfail('fill'),
         skip('masked.mean'),  # ???
         xfail('masked_scatter'),
-        xfail('index_fill'),
         xfail('put'),
         xfail('take'),
         xfail('nn.functional.max_pool3d'),
@@ -1114,8 +1113,6 @@
         xfail('fill'),
         xfail('narrow'),  # Batching rule not implemented for `narrow.Tensor` (and view op)
         xfail('special.log_ndtr'),
-        xfail('index_copy'),
-        xfail('index_fill'),
         xfail('linalg.householder_product'),
         xfail('lu'),
         xfail('lu_solve'),
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 404e7c8..632b407 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3613,7 +3613,6 @@
         xfail('native_batch_norm'),
         xfail('_native_batch_norm_legit'),
         xfail('histogram'),
-        xfail('index_fill'),
         xfail('scatter_reduce', 'sum'),
         xfail('scatter_reduce', 'mean'),
         xfail('scatter_reduce', 'amax'),
@@ -3861,11 +3860,83 @@
         # There's no OpInfo for this
         def test():
             B = 2
-            x = torch.randn(2, 5, 5, device=device)
+            x = torch.randn(B, 5, 5, device=device)
             self.vmap_outplace_test(torch.slogdet, (x,), {}, (0,))
 
         check_vmap_fallback(self, test, torch.slogdet)
 
+    def test_index_fill(self, device):
+        # There's no OpInfo covering these cases, so here are extra tests for them
+
+        B = 2
+
+        def test1():
+            # negative dim
+            x = torch.randn(B, 5, 5, device=device)
+            dim = -2
+            index = torch.tensor([[2, 3], [0, 4]], device=device)
+            value = 5.0
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+        def test2():
+            # self batched, self logical rank 1, index logical rank 1
+            x = torch.zeros(B, 3, device=device)
+            dim = 0
+            index = torch.tensor([[0], [1]], device=device)
+            value = 1
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
+
+        def test3():
+            # self batched, self logical rank 1, index logical rank 0
+            x = torch.zeros(B, 3, device=device)
+            dim = 0
+            index = torch.tensor([0, 1], device=device)
+            value = 1
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
+
+        def test4():
+            # self not batched, self logical rank 0, index logical rank 1
+            x = torch.zeros([], device=device)
+            dim = 0
+            index = torch.tensor([[0], [0]], device=device)
+            value = 1
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+        def test5():
+            # self not batched, self logical rank 0, index logical rank 0
+            x = torch.zeros([], device=device)
+            dim = 0
+            index = torch.tensor([0, 0], device=device)
+            value = 1
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+        def test6():
+            # self not batched, self logical rank 1, index logical rank 1
+            x = torch.zeros(3, device=device)
+            dim = 0
+            index = torch.tensor([[0], [1]], device=device)
+            value = 1
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+        def test7():
+            # self not batched, self logical rank 1, index logical rank 0
+            x = torch.zeros(3, device=device)
+            dim = 0
+            index = torch.tensor([0, 1], device=device)
+            value = 1
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+        def test8():
+            # self batched, self logical rank > 1, index logical rank 0
+            x = torch.zeros(B, 3, 3, device=device)
+            dim = 0
+            index = torch.tensor([0, 1], device=device)
+            value = 1
+            self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
+
+        for test in (test1, test2, test3, test4, test5, test6, test7, test8):
+            check_vmap_fallback(self, test, torch.index_fill)
+
     def test_fill__Tensor(self, device):
         # There's no OpInfo for fill_.Tensor, so here's an extra test for it.
         def test():