Add vmap support for torch.index_fill (#91364)
Fixes #91177
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91364
Approved by: https://github.com/zou3519
diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
index da1711e..a346e5f 100644
--- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -1056,6 +1056,164 @@
return std::make_tuple(result, 0);
}
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
+ Tensor & self, optional<int64_t> self_bdim,
+ int64_t dim,
+ const Tensor & index, optional<int64_t> index_bdim,
+ const Scalar & value,
+ const bool inplace) {
+ const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+ const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
+ Tensor self_ = moveBatchDimToFront(self, self_bdim);
+ Tensor index_ = moveBatchDimToFront(index, index_bdim);
+ dim = maybe_wrap_dim(dim, self_logical_rank);
+
+ if (inplace && !self_bdim.has_value()) {
+ vmapIncompatibleInplaceError("index_fill_");
+ }
+
+ if (!index_bdim) {
+ if (self_logical_rank == 0){
+ self_.unsqueeze_(-1);
+ }
+ self_.index_fill_(dim + 1, index_, value);
+ if (self_logical_rank == 0) {
+ self_.squeeze_(-1);
+ }
+ return std::make_tuple(self_, 0);
+ }
+
+ auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+ self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+ index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
+
+ if (inplace) {
+ // Do for-loop for in-place because we cannot reshape
+ // `self_` having an incompatible stride without copying
+ for (const auto i : c10::irange(0, batch_size)) {
+ const auto& self_slice = self_.select(0, i);
+ const auto& index_slice = index_.select(0, i);
+ self_slice.index_fill_(
+ dim,
+ index_slice,
+ value
+ );
+ }
+ return std::make_tuple(self_, 0);
+ }
+
+ self_ = self_bdim.has_value() ? self_ : self_.clone();
+
+ if (self_logical_rank != 0){
+ auto index_offset = at::arange(
+ batch_size,
+ at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
+ );
+ if (index_logical_rank == 0){
+ index_ = index_.unsqueeze(-1);
+ }
+ index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
+ index_ = reshape_dim_into(0, 0, index_);
+ self_ = reshape_dim_into(0, dim, self_);
+ self_.index_fill_(dim, index_, value);
+ self_ = reshape_dim_outof(dim, batch_size, self_);
+ return std::make_tuple(self_, dim);
+ }
+
+ // If self_logical_rank == 0, the batch dim is certainly 0, and we must apply batched indices to each row.
+ if (index_logical_rank != 0){
+ index_ = reshape_dim_into(0, 0, index_);
+ }
+ self_.unsqueeze_(-1);
+ self_.index_fill_(dim + 1, index_, value);
+ self_.squeeze_(-1);
+
+ return std::make_tuple(self_, 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
+ Tensor & self, optional<int64_t> self_bdim,
+ int64_t dim,
+ const Tensor & index, optional<int64_t> index_bdim,
+ const Tensor & value, optional<int64_t> value_bdim,
+ const bool inplace) {
+ const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+ Tensor self_ = moveBatchDimToFront(self, self_bdim);
+ Tensor index_ = moveBatchDimToFront(index, index_bdim);
+ Tensor value_ = moveBatchDimToFront(value, value_bdim);
+ dim = maybe_wrap_dim(dim, self_logical_rank);
+
+ if (inplace && !self_bdim.has_value()) {
+ vmapIncompatibleInplaceError("index_fill_");
+ }
+
+ if (!index_bdim && !value_bdim) {
+ if (self_logical_rank == 0){
+ self_.unsqueeze_(-1);
+ }
+ self_.index_fill_(dim + 1, index_, value);
+ if (self_logical_rank == 0) {
+ self_.squeeze_(-1);
+ }
+ return std::make_tuple(self_, 0);
+ }
+
+ auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+ self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+ index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
+ value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
+
+ self_ = self_bdim.has_value() ? self_ : self_.clone();
+
+ for (const auto i : c10::irange(0, batch_size)) {
+ const auto& self_slice = self_.select(0, i);
+ const auto& index_slice = index_.select(0, i);
+ const auto& value_slice = value_.select(0, i);
+ self_slice.index_fill_(
+ dim,
+ index_slice,
+ value_slice
+ );
+ }
+
+ return std::make_tuple(self_, 0);
+}
+
+void index_fill__int_scalar_batch_rule(
+ Tensor & self, optional<int64_t> self_bdim,
+ int64_t dim,
+ const Tensor & index, optional<int64_t> index_bdim,
+ const Scalar & value) {
+ index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
+}
+
+void index_fill__int_tensor_batch_rule(
+ Tensor & self, optional<int64_t> self_bdim,
+ int64_t dim,
+ const Tensor & index, optional<int64_t> index_bdim,
+ const Tensor & value, optional<int64_t> value_bdim) {
+ index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+ const Tensor & self, optional<int64_t> self_bdim,
+ int64_t dim,
+ const Tensor & index, optional<int64_t> index_bdim,
+ const Scalar & value) {
+ auto self_ = self.clone(at::MemoryFormat::Preserve);
+ return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+ const Tensor & self, optional<int64_t> self_bdim,
+ int64_t dim,
+ const Tensor & index, optional<int64_t> index_bdim,
+ const Tensor & value, optional<int64_t> value_bdim) {
+ auto self_ = self.clone(at::MemoryFormat::Preserve);
+ return index_fill_int_tensor_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, value_bdim, false);
+}
+
+
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("index.Tensor", index_plumbing);
m.impl("index_put_", index_put__plumbing);
@@ -1066,6 +1224,10 @@
m.impl("index_copy", index_copy_decomp);
m.impl("index_select", index_select_decomp);
VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule);
+ VMAP_SUPPORT2(index_fill_, int_Tensor, index_fill__int_tensor_batch_rule);
+ VMAP_SUPPORT2(index_fill_, int_Scalar, index_fill__int_scalar_batch_rule);
+ VMAP_SUPPORT2(index_fill, int_Tensor, index_fill_int_tensor_batch_rule);
+ VMAP_SUPPORT2(index_fill, int_Scalar, index_fill_int_scalar_batch_rule);
VMAP_SUPPORT(index_add, index_add_batch_rule);
VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
VMAP_SUPPORT(gather, gather_batch_rule);
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 633c1ab..5265490 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1043,7 +1043,6 @@
xfail('fill'),
skip('masked.mean'), # ???
xfail('masked_scatter'),
- xfail('index_fill'),
xfail('put'),
xfail('take'),
xfail('nn.functional.max_pool3d'),
@@ -1114,8 +1113,6 @@
xfail('fill'),
xfail('narrow'), # Batching rule not implemented for `narrow.Tensor` (and view op)
xfail('special.log_ndtr'),
- xfail('index_copy'),
- xfail('index_fill'),
xfail('linalg.householder_product'),
xfail('lu'),
xfail('lu_solve'),
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 404e7c8..632b407 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3613,7 +3613,6 @@
xfail('native_batch_norm'),
xfail('_native_batch_norm_legit'),
xfail('histogram'),
- xfail('index_fill'),
xfail('scatter_reduce', 'sum'),
xfail('scatter_reduce', 'mean'),
xfail('scatter_reduce', 'amax'),
@@ -3861,11 +3860,83 @@
# There's no OpInfo for this
def test():
B = 2
- x = torch.randn(2, 5, 5, device=device)
+ x = torch.randn(B, 5, 5, device=device)
self.vmap_outplace_test(torch.slogdet, (x,), {}, (0,))
check_vmap_fallback(self, test, torch.slogdet)
+ def test_index_fill(self, device):
+ # There's no OpInfo for these tests
+
+ B = 2
+
+ def test1():
+ # negative dim
+ x = torch.randn(B, 5, 5, device=device)
+ dim = -2
+ index = torch.tensor([[2, 3], [0, 4]], device=device)
+ value = 5.0
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+ def test2():
+ # self batched, self logical rank 1, index logical rank 1
+ x = torch.zeros(B, 3, device=device)
+ dim = 0
+ index = torch.tensor([[0], [1]], device=device)
+ value = 1
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
+
+ def test3():
+ # self batched, self logical rank 1, index logical rank 0
+ x = torch.zeros(B, 3, device=device)
+ dim = 0
+ index = torch.tensor([0, 1], device=device)
+ value = 1
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
+
+ def test4():
+ # self not batched, self logical rank 0, index logical rank 1
+ x = torch.zeros([], device=device)
+ dim = 0
+ index = torch.tensor([[0], [0]], device=device)
+ value = 1
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+ def test5():
+ # self not batched, self logical rank 0, index logical rank 0
+ x = torch.zeros([], device=device)
+ dim = 0
+ index = torch.tensor([0, 0], device=device)
+ value = 1
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+ def test6():
+ # self not batched, self logical rank 0, index logical rank 1
+ x = torch.zeros(3, device=device)
+ dim = 0
+ index = torch.tensor([[0], [1]], device=device)
+ value = 1
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+ def test7():
+ # self not batched, self logical rank 0, index logical rank 0
+ x = torch.zeros(3, device=device)
+ dim = 0
+ index = torch.tensor([0, 1], device=device)
+ value = 1
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
+
+ def test8():
+ # self batched, self logical rank > 1, index logical rank 0
+ x = torch.zeros(B, 3, 3, device=device)
+ dim = 0
+ index = torch.tensor([0, 1], device=device)
+ value = 1
+ self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
+
+ for test in (test1, test2, test3, test4, test5, test6, test7, test8):
+ check_vmap_fallback(self, test, torch.index_fill)
+
def test_fill__Tensor(self, device):
# There's no OpInfo for fill_.Tensor, so here's an extra test for it.
def test():