// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
namespace at { namespace functorch {
static bool any_has_value(ArrayRef<optional<int64_t>> bdims) {
for (const auto& bdim : bdims) {
if (bdim.has_value()) {
return true;
}
}
return false;
}
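// Returns the number of leading None / undefined entries in `indices`, i.e.
// the length of the basic-indexing prefix before the first advanced index.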
static int64_t get_num_leading_nones(ArrayRef<optional<Tensor>> indices) {
int64_t result = 0;
for (const auto& idx : indices) {
if (!idx.has_value() || !idx->defined()) {
result++;
} else {
return result;
}
}
return result;
}
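// Returns the largest logical (i.e. non-batch) rank among the defined index
// tensors, or -1 if every entry is None/undefined.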
static int64_t get_max_index_logical_dim(
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims) {
int64_t max_logical_dim = -1;
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
TORCH_INTERNAL_ASSERT(indices.size() > 0);
for (const auto i : c10::irange(0, indices.size())) {
const auto& maybe_tensor = indices[i];
if (!maybe_tensor.has_value() || !maybe_tensor->defined()) {
continue;
}
auto logical_dim = rankWithoutBatchDim(maybe_tensor.value(), indices_bdims[i]);
max_logical_dim = std::max(logical_dim, max_logical_dim);
}
return max_logical_dim;
}
std::vector<optional<Tensor>> batchIndices(
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims,
int64_t batch_size,
optional<int64_t> self_bdim,
optional<int64_t> values_bdim = nullopt) {
// There are 3 main cases:
// 1. self is batched, indices/values are not batched
// In this case, we just need to augment indices with a None at the front to
// basically broadcast the indexing across the batch dimension of self.
//
// 2. self is not batched, some indices are batched.
// In this case, we don't need to do anything - indices will automatically
// broadcast to work with the unbatched self.
//
// 3. self is batched, some indices are batched.
// In this case, we simply need to add an arange that indexes along the first
// dimension (i.e. the batch dimension). We also need to make sure this
// broadcasts with the rest of the indices.
//
// In all three cases, depending on if advanced indices are adjacent we will
// have to permute the output.
// See NOTE: [advanced indexing (index.Tensor) batch rule] for more details
//
// There is one more case worth mentioning - boolean tensor indices. If we
// have "batched" boolean tensor indices, that is unrepresentable, as each
// batch would result in a tensor with a different shape.
std::vector<optional<Tensor>> indices_;
int64_t maxLogicalRank = get_max_index_logical_dim(indices, indices_bdims);
bool indices_batched = any_has_value(indices_bdims);
for (size_t i = 0; i < indices.size(); i++) {
auto index = indices[i];
if (index.has_value() && index->numel() != 0) {
const auto idx_bdim = indices_bdims[i];
indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
}
} else {
indices_.push_back(index);
}
}
auto maxIndexDim = maxLogicalRank;
if (indices_batched || values_bdim.has_value()) {
maxIndexDim += 1;
}
if (!indices_batched && self_bdim.has_value()) {
indices_.insert(indices_.begin(), nullopt);
} else if (indices_batched && !self_bdim.has_value()) {
// do nothing
} else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) {
auto arange_index = at::arange(0, batch_size);
while (arange_index.dim() < maxIndexDim) {
arange_index = arange_index.unsqueeze(-1);
}
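    // A sketch of the shapes involved: with batch_size B and maxIndexDim = 2,
    // arange_index ends up with shape [B, 1, 1]. Prepending it means batch b
    // of the output only gathers from slice b of self, while broadcasting
    // against the other (batched) index tensors.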
// TODO: this is O(N)
indices_.insert(indices_.begin(), arange_index);
}
return indices_;
}
// Define an "advanced index" to be a selection object that is
// a non-trivial Tensor (i.e. it does not represent :).
static bool is_advanced_index(const optional<Tensor>& idx) {
if (!idx.has_value()) {
return false;
}
if (!idx->defined()) {
return false;
}
return true;
}
// See NOTE: [advanced indices adjacent] for definition
static bool are_advanced_indices_adjacent(ArrayRef<optional<Tensor>> indices) {
int64_t num_advanced_indices_regions = 0;
bool in_advanced_indices_region = false;
for (const auto& idx : indices) {
if (!in_advanced_indices_region && is_advanced_index(idx)) {
num_advanced_indices_regions++;
in_advanced_indices_region = true;
continue;
}
if (in_advanced_indices_region && !is_advanced_index(idx)) {
in_advanced_indices_region = false;
continue;
}
}
return num_advanced_indices_regions <= 1;
}
// Given a Tensor[B, <first_region>, <second_region>, ...]
// Swaps the regions to produce Tensor[B, <second_region>, <first_region>, ...]
//
// Concretely speaking, given
// - tensor: Tensor[B, 2, 3, 4, 5, 6, 7, 8]
// - first_region_size: 2
// - second_region_size: 3
// Produces:
// - result: Tensor[B, 4, 5, 6, 2, 3, 7, 8]
//                     -------  ----
//                     region2  region1
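//
// A minimal usage sketch (shapes only, with B = 2):
//   auto t = at::zeros({2, 2, 3, 4, 5, 6, 7, 8});
//   auto s = swap_regions(t, /*first_region_size=*/2, /*second_region_size=*/3);
//   // s.sizes() == [2, 4, 5, 6, 2, 3, 7, 8]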
static Tensor swap_regions(const Tensor& tensor, int64_t first_region_size, int64_t second_region_size) {
VmapDimVector permutation(tensor.dim(), 0);
std::iota(permutation.begin(), permutation.end(), 0);
std::rotate(
permutation.begin() + 1,
permutation.begin() + 1 + first_region_size,
permutation.begin() + 1 + first_region_size + second_region_size);
return tensor.permute(permutation);
}
std::tuple<Tensor,optional<int64_t>> index_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims) {
// NOTE: [advanced indexing (index.Tensor) batch rule]
//
// This is a three step procedure:
// 1. batch `indices`. Depends on self_bdim and indices_bdim.
// 2. call at::index
// 3. (maybe) reorder the dimensions in the result.
// Why is step 3 necessary? Let's take a detour first.
//
// NOTE: [advanced indices adjacent]
// Definition: In a list of optional<Tensor> indices,
// we say that the "advanced indices are adjacent" if no advanced index is
// separated from another advanced index by a None (slice).
//
// So, for example,
// [:, :, (0, 1), (0, 1), :] -> True
// [:, (0, 1), :, (0, 1), :] -> False, the advanced indices are separated by a slice
//
// See https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
// for more details.
//
// NOTE: [Why is step 3 necessary?]
//
// In the original self[*indices] expression,
// depending on whether or not the "advanced indices inside `indices` are
// adjacent", something different happens.
//
// For example:
// - self: Tensor[4, 5, 6, 7]
// - indices: [:, (0, 1), (0, 1), :] (advanced indices are adjacent)
// - self[*indices]: Tensor[4, 2, 7]
// If advanced indices are adjacent, you get the output you would expect.
// (0, 1), (0, 1) says "please index these two dimensions at (0, 0) and (1, 1)
// to produce two elements".
//
// If advanced indices are not adjacent, it is ambiguous where the new
// dimension of size 2 should go. The numpy spec says it should go at the very
// front of the Tensor.
//
// - self: Tensor[4, 5, 6, 7]
// - indices: [:, (0, 1), :, (0, 1)] (advanced indices not adjacent)
// - self[*indices]: Tensor[2, 4, 6]
//
// Now, this leads to some weird interactions with vmap.
// The indices might originally have adjacent advanced indices, but after
// batching them with "batchIndices", they may no longer be adjacent!
// - indices: [:, (0, 1), (0, 1)]
// - batched_indices (for example): [(0, 1), :, (0, 1), (0, 1)]
// This leads to the dimension of size 2 appearing somewhere else.
//
// There are a couple of different cases that we walk through in the code below.
//
// Background reading for why we care about if the advanced indices are adjacent:
// https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
auto self_ = moveBatchDimToFront(self, self_bdim);
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(indices);
// Step 1
const auto batched_indices = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
auto num_leading_nones = get_num_leading_nones(indices);
auto max_index_dim = get_max_index_logical_dim(indices, indices_bdims);
// Step 2
auto res = at::index(self_, List<optional<Tensor>>(batched_indices));
// Step 3: There are three cases (these match the cases outlined in batchIndices)
bool self_batched = self_bdim.has_value();
bool indices_batched = any_has_value(indices_bdims);
  TORCH_INTERNAL_ASSERT(self_batched || indices_batched, "Requires at least one batched input to get here");
// Case 1
if (self_batched && !indices_batched) {
if (advanced_indices_are_adjacent) {
// self: Tensor[B, 5, 6, 7, 8]
// indices: [:, Tensor[2, 2], Tensor[2, 2], :]
// batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
// res: Tensor[B, 5, 2, 2, 8]
return std::make_tuple(res, 0);
} else {
// self: Tensor[B, 5, 6, 7]
// indices: [Tensor[2, 2], :, Tensor[2, 2]]
// batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
// res: Tensor[2, 2, B, 6]
return std::make_tuple(res, max_index_dim);
}
}
// Case 2
if (!self_batched && indices_batched) {
if (advanced_indices_are_adjacent) {
// self: Tensor[5, 6, 7, 8]
// indices: [:, :, Tensor[B, 2, 2], Tensor[2, 2]]
// batched_indices: indices (no change)
// res: Tensor[5, 6, B, 2, 2]
return std::make_tuple(res, num_leading_nones);
} else {
// self: Tensor[5, 6, 7, 8, 9]
// indices: [:, :, Tensor[B, 2, 2], :, Tensor[2, 2]]
// batched_indices: indices (no change)
// res: Tensor[B, 2, 2, 5, 6, 8]
return std::make_tuple(res, 0);
}
}
// Case 3: self_batched and indices_batched
TORCH_INTERNAL_ASSERT(self_batched && indices_batched);
if (!advanced_indices_are_adjacent) {
// self: Tensor[B, 5, 6, 7, 8]
// indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
// batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
// res: Tensor[B, 2, 2, 5, 7]
return std::make_tuple(res, 0);
}
// In other words, in batched_indices, advanced indices are adjacent
if (num_leading_nones == 0) {
// self: Tensor[B, 5, 6, 7, 8]
// indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
// batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
// res: Tensor[B, 2, 2, 7, 8]
return std::make_tuple(res, 0);
}
// This is the tricky case. In indices, advanced indices are adjacent.
// In batched_indices, advanced indices are no longer adjacent
//
// self: Tensor[B, 5, 6, 7, 8, 9]
// indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
// batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
// res: Tensor[B, 2, 3, 5, 6, 9]
// expected: Tensor[B, 5, 6, 2, 3, 9]
//
// The resolution is to move dims around until we get the right shape.
  // The result is set up as [B, <maxIndexDim>, <leading_nones>, ...];
  // we just have to move the <leading_nones> to before the <maxIndexDim> to
  // produce [B, <leading_nones>, <maxIndexDim>, ...]
return std::make_tuple(swap_regions(res, max_index_dim, num_leading_nones), 0);
}
// plumbing done since we don't support List<optional<Tensor>> in codegen
Tensor index_plumbing(const Tensor & self, const List<optional<Tensor>> & indices) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
auto maybe_layer = maybeCurrentDynamicLayer();
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
int64_t cur_level = maybe_layer->layerId();
if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) {
return at::index(self, indices);
}
Tensor self_value;
optional<int64_t> self_bdim;
std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
std::vector<optional<Tensor>> indices_value;
std::vector<optional<int64_t>> indices_bdims;
for (const auto&& indRef : indices) {
optional<Tensor> ind = indRef;
optional<Tensor> index;
optional<int64_t> index_bdim;
if (ind.has_value()) {
std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
}
indices_value.push_back(index);
indices_bdims.push_back(index_bdim);
}
auto results = index_batch_rule(self_value, self_bdim, indices_value, indices_bdims);
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}
namespace {
// Code is mostly duplicated from
// https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
// /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L294-L312
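// For example (a sketch): src of shape [5, 6, 7] with
// indices_list = {undefined, idx}, where idx has shape [2, 2], yields
// [5, 2, 2, 7]: the indexed dim (of size 6) is replaced by idx's sizes.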
VmapDimVector compute_indexed_shape(const Tensor &src, TensorList indices_list)
{
int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
IntArrayRef replacement_shape;
for (const auto dim : c10::irange(indices_list.size())) {
if (!indices_list[dim].defined()) {
if (dims_indexed == 0) {
dims_before++;
} else {
dims_after++;
}
} else {
dims_indexed++;
replacement_shape = indices_list[dim].sizes();
}
}
// Replace indexed dimensions in src with stride 0 and the size of the result tensor.
// The offset in these dimensions is computed by the kernel using the index tensor's
// values and the stride of src. The new shape is not meaningful. It's used to make
// the shape compatible with the result tensor.
auto shape = VmapDimVector(src.sizes());
int64_t end = dims_before + dims_indexed;
shape.erase(shape.begin() + dims_before, shape.begin() + end);
shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
return shape;
}
// Code is mostly duplicated from
// https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
// /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L379-L405
VmapDimVector get_indexed_shape(Tensor self, const torch::List<c10::optional<at::Tensor>> &orig)
{
at::native::checkIndexTensorTypes(orig);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
auto indices = at::native::expandTensors(self, orig);
// next broadcast all index tensors together
  try {
    indices = at::expand_outplace(indices);
  } catch (std::exception&) {
    TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together");
  }
// add missing null Tensors so that it matches self.dim()
while (indices.size() < static_cast<size_t>(self.dim())) {
indices.emplace_back();
}
// if the non-null indices are not all adjacent, transpose self and indices
// together so that they're adjacent at the front
if (!at::native::hasContiguousSubspace(indices)) {
std::tie(self, indices) = at::native::transposeToFront(self, indices);
}
return compute_indexed_shape(self, indices);
}
std::tuple<Tensor, std::vector<optional<Tensor>>, Tensor>
index_put_batch_rule_helper(const Tensor &self,
optional<int64_t> self_bdim,
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims,
const Tensor &values,
optional<int64_t> values_bdim,
optional<int64_t> opt_batch_size = {}) {
Tensor self_ = moveBatchDimToFront(self, self_bdim);
Tensor values_ = moveBatchDimToFront(values, values_bdim);
  // For the in-place variants `index_put_` and `_index_put_impl_` we find the
  // batch_size here, while for `index_put` it is computed outside of this function.
const auto batch_size = opt_batch_size ? opt_batch_size.value() : self_.size(0);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
values_ = ensure_has_bdim(values_, values_bdim.has_value(), batch_size);
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
// we've already made sure that self has bdim at 0.
const auto indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);
auto indexed_shape = get_indexed_shape(self_, List<optional<Tensor>>(indices_));
  // handle broadcasting support for values
  // E.g. if `indexed_shape.size()` is 5 and the shape of `values` is (N, 2, 3),
  // the following block reshapes `values` to (N, 1, 1, 2, 3).
  if (static_cast<int64_t>(indexed_shape.size()) > values_.dim()) {
auto values_sizes = values_.sizes();
// number of unit dims (for broadcasting value to indexed_shape)
auto n_unit_dims = indexed_shape.size() - values_sizes.size();
VmapDimVector new_values_shape(values_sizes.size() + n_unit_dims);
// add the batch-dim
new_values_shape[0] = batch_size;
// insert the unit dims for broadcasting.
for (const auto idx : c10::irange(n_unit_dims)) {
      // the batch dim (at index 0) is already filled.
new_values_shape[idx + 1] = 1;
}
for (const auto idx: c10::irange(1, values_sizes.size())) {
      // the batch and unit dims are already filled.
new_values_shape[idx + n_unit_dims] = values_sizes[idx];
}
values_ = values_.view(new_values_shape);
}
return std::make_tuple(self_, indices_, values_);
}
auto unpackSelfAndIndicesAndValuesAtCurrentLevel(const Tensor &self,
const List<optional<Tensor>> &indices,
const Tensor &values, int64_t cur_level)
{
Tensor self_value;
optional<int64_t> self_bdim;
std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
std::vector<optional<Tensor>> indices_value;
std::vector<optional<int64_t>> indices_bdims;
for (const auto &&indRef : indices)
{
optional<Tensor> ind = indRef;
optional<Tensor> index;
optional<int64_t> index_bdim;
if (ind.has_value())
{
std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
}
indices_value.push_back(index);
indices_bdims.push_back(index_bdim);
}
Tensor values_value;
optional<int64_t> values_bdim;
std::tie(values_value, values_bdim) = unwrapTensorAtLevel(values, cur_level);
return std::make_tuple(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim);
}
} // namespace
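// Batch rule for the in-place index_put_. Since the op writes into self,
// self must carry the batch dim; otherwise this is vmap-incompatible.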
void index_put__batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims,
const Tensor& values,
optional<int64_t> values_bdim,
bool accumulate) {
if (!self_bdim.has_value()) {
vmapIncompatibleInplaceError("index_put_");
}
Tensor self_, values_;
std::vector<optional<Tensor>> indices_;
std::tie(self_, indices_, values_) = index_put_batch_rule_helper(
self, self_bdim, indices, indices_bdims, values, values_bdim);
at::index_put_(self_, List<optional<Tensor>>(indices_), values_, accumulate);
}
// plumbing done since we don't support List<optional<Tensor>> in codegen
Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indices,
                            const Tensor & values, bool accumulate) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
auto maybe_layer = maybeCurrentDynamicLayer();
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
int64_t cur_level = maybe_layer->layerId();
if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
return self.index_put_(indices, values, accumulate);
}
Tensor self_value, values_value;
optional<int64_t> self_bdim, values_bdim;
std::vector<optional<Tensor>> indices_value;
std::vector<optional<int64_t>> indices_bdims;
std::tie(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim) =
unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
index_put__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
return self;
}
void _index_put_impl__batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims,
const Tensor& values,
optional<int64_t> values_bdim,
bool accumulate,
bool unsafe) {
if (!self_bdim.has_value()) {
vmapIncompatibleInplaceError("_index_put_impl_");
}
Tensor self_, values_;
std::vector<optional<Tensor>> indices_;
std::tie(self_, indices_, values_) = index_put_batch_rule_helper(
self, self_bdim, indices, indices_bdims, values, values_bdim);
at::_index_put_impl_(self_, List<optional<Tensor>>(indices_), values_, accumulate, unsafe);
}
// plumbing done since we don't support List<optional<Tensor>> in codegen
Tensor &_index_put_impl__plumbing(Tensor &self, const List<optional<Tensor>> &indices,
const Tensor &values, bool accumulate, bool unsafe) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
auto maybe_layer = maybeCurrentDynamicLayer();
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
int64_t cur_level = maybe_layer->layerId();
if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
return at::_index_put_impl_(self, indices, values, accumulate, unsafe);
}
Tensor self_value, values_value;
optional<int64_t> self_bdim, values_bdim;
std::vector<optional<Tensor>> indices_value;
std::vector<optional<int64_t>> indices_bdims;
std::tie(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim) =
unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
_index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
return self;
}
static Tensor maybe_permute_values(
const Tensor& values,
ArrayRef<optional<Tensor>> orig_indices,
ArrayRef<optional<int64_t>> orig_indices_bdims) {
bool indices_batched = any_has_value(orig_indices_bdims);
bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(orig_indices);
auto num_leading_nones = get_num_leading_nones(orig_indices);
auto max_index_dim = get_max_index_logical_dim(orig_indices, orig_indices_bdims);
TORCH_INTERNAL_ASSERT(values.dim() >= num_leading_nones + max_index_dim);
// NB: values has its B dimension at the front
if (!indices_batched) {
if (advanced_indices_are_adjacent) {
// self: Tensor[B, 5, 6, 7, 8]
// indices: [:, Tensor[2, 2], Tensor[2, 2], :]
// batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
// required values: Tensor[B, 5, 2, 2, 8]
return values;
}
// self: Tensor[B, 5, 6, 7]
// indices: [Tensor[2, 2], :, Tensor[2, 2]]
// batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
// required values: Tensor[2, 2, B, 6]
return values.movedim(0, max_index_dim);
}
if (!advanced_indices_are_adjacent) {
// self: Tensor[B, 5, 6, 7, 8]
// indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
// batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
// required values: Tensor[B, 2, 2, 5, 7]
return values;
}
// In other words, in batched_indices, advanced indices are adjacent
if (num_leading_nones == 0) {
// self: Tensor[B, 5, 6, 7, 8]
// indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
// batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
// required values: Tensor[B, 2, 2, 7, 8]
return values;
}
// This is the tricky case. In indices, advanced indices are adjacent.
// In batched_indices, advanced indices are no longer adjacent
//
// self: Tensor[B, 5, 6, 7, 8, 9]
// indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
// batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
// required values: Tensor[B, 2, 3, 5, 6, 9]
// actual values: Tensor[B, 5, 6, 2, 3, 9]
//
// The resolution is to move dims around until we get the right shape.
  // values is set up as [B, <leading_nones>, <maxIndexDim>, ...];
  // we just have to move the <maxIndexDim> to before the <leading_nones> to
  // produce [B, <maxIndexDim>, <leading_nones>, ...]
return swap_regions(values, num_leading_nones, max_index_dim);
}
std::tuple<Tensor,optional<int64_t>> index_put_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims,
const Tensor& values,
optional<int64_t> values_bdim,
bool accumulate) {
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
// find the batch_size
int64_t batch_size = 0;
if (self_bdim || values_bdim) {
batch_size = get_bdim_size2(self, self_bdim, values, values_bdim);
} else {
// one or more of the indices is batched.
for (size_t i = 0; i < indices.size(); i++) {
if (indices_bdims[i] && indices[i].has_value()) {
batch_size = indices[i].value().size(*indices_bdims[i]);
break;
}
}
}
Tensor self_, values_;
std::vector<optional<Tensor>> indices_;
std::tie(self_, indices_, values_) = index_put_batch_rule_helper(
self, self_bdim, indices, indices_bdims, values, values_bdim, batch_size);
// Why do we need to permute values?
  // See NOTE: [advanced indexing (index.Tensor) batch rule] for details,
// but the gist is that index_put effectively does the following:
// - result = self_.clone()
// - result[indices_] = values
// - return result
  // Now, the problem is that result[indices_] might return a Tensor whose shape
  // is the shape of values, but permuted. This is because the shape of
  // result[indices_] depends on whether the original indices "have adjacent
  // advanced indices", and batching `indices_` might change that property.
values_ = maybe_permute_values(values_, indices, indices_bdims);
auto result = at::index_put(self_, List<optional<Tensor>>(indices_), values_, accumulate);
return std::make_tuple(result, 0);
}
// plumbing done since we don't support List<optional<Tensor>> in codegen
Tensor index_put_plumbing(const Tensor & self, const List<optional<Tensor>> & indices,
const Tensor & values, bool accumulate) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
auto maybe_layer = maybeCurrentDynamicLayer();
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
int64_t cur_level = maybe_layer->layerId();
if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
return self.index_put(indices, values, accumulate);
}
Tensor self_value, values_value;
optional<int64_t> self_bdim, values_bdim;
std::vector<optional<Tensor>> indices_value;
std::vector<optional<int64_t>> indices_bdims;
std::tie(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim) =
unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
auto results = index_put_batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}
namespace {
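// Generic batch rule for the scatter family of ops: materialize a batch dim
// at dim 0 on self and index (and src, in the overload below), shift `dim` by
// one to account for it, call the underlying op `f`, and report the result's
// batch dim as 0.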
template<typename Func, typename ...Args>
std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
Func f,
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Scalar& value, Args... args) {
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);
if (self_logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}
if (index_logical_rank == 0) {
index_ = index_.unsqueeze(-1);
}
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
auto result = f(self_, physical_dim, index_, value, args...);
// result should have same shape as self
if (self_logical_rank == 0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, 0);
}
template <typename Func, typename ...Args>
inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
Func f,
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Tensor& src, optional<int64_t> src_bdim, Args... args) {
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, src, src_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);
auto src_ = moveBatchDimToFront(src, src_bdim);
if (self_logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}
if (index_logical_rank == 0) {
index_ = index_.unsqueeze(-1);
}
if (src_logical_rank == 0) {
src_ = src_.unsqueeze(-1);
}
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
auto result = f(self_, physical_dim, index_, src_, args...);
// result should have same shape as self
if (self_logical_rank == 0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, 0);
}
} // namespace
std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Scalar& value) {
return scatter_batch_rule(ATEN_FN2(scatter, value),
self, self_bdim, dim, index, index_bdim, value);
}
std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Tensor& src, optional<int64_t> src_bdim) {
return scatter_batch_rule(ATEN_FN2(scatter, src),
self, self_bdim, dim, index, index_bdim, src, src_bdim);
}
std::tuple<Tensor,optional<int64_t>> scatter_add_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Tensor& src, optional<int64_t> src_bdim) {
return scatter_batch_rule(ATEN_FN(scatter_add),
self, self_bdim, dim, index, index_bdim, src, src_bdim);
}
std::tuple<Tensor,optional<int64_t>> scatter_reduce_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Tensor& src, optional<int64_t> src_bdim,
const c10::string_view reduce) {
return scatter_batch_rule(ATEN_FN2(scatter, reduce),
self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
}
std::tuple<Tensor,optional<int64_t>> scatter_value_reduce_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Scalar& src,
const c10::string_view reduce) {
return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
self, self_bdim, dim, index, index_bdim, src, reduce);
}
std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
bool sparse_grad) {
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);
if (self_logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}
if (index_logical_rank == 0) {
index_ = index_.unsqueeze(-1);
}
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
auto result = at::gather(self_, physical_dim, index_, sparse_grad);
// result should have same rank as index
if (index_logical_rank == 0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, 0);
}
std::tuple<Tensor,optional<int64_t>> gather_backward_batch_rule(
const Tensor& grad, optional<int64_t> grad_bdim,
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
bool sparse_grad) {
auto batch_size = get_bdim_size3(grad, grad_bdim, self, self_bdim, index, index_bdim);
auto grad_ = moveBatchDimToFront(grad, grad_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
auto grad_logical_rank = rankWithoutBatchDim(grad, grad_bdim);
if (grad_logical_rank == 0) {
grad_ = grad_.unsqueeze(-1);
}
if (self_logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}
if (index_logical_rank == 0) {
index_ = index_.unsqueeze(-1);
}
grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), batch_size);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
auto result = at::gather_backward(grad_, self_, physical_dim, index_, sparse_grad);
  // result should have the same rank as self
if (self_logical_rank == 0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, 0);
}
namespace {
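// Expands `index` so it can be passed to at::gather / at::scatter in place of
// an index_select / index_copy style index. For example (a sketch):
//   index = tensor([0, 2]), self_size = [5, 6, 7], dim = 1
// is first viewed as [1, 2, 1] and then expanded to [5, 2, 7].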
Tensor get_expanded_index(const Tensor& index, IntArrayRef self_size, int64_t dim) {
if (index.dim() == 0) {
return index.expand(self_size);
}
  // set up new_index_shape as [BS, 1, ..., idx_size, ..., 1]
  // to reshape index_
auto idx_size = index.size(0); // get non-batch size of index tensor
Tensor index_;
{
VmapDimVector new_index_shape(self_size.size(), 1);
new_index_shape[dim] = idx_size;
index_ = index.view(new_index_shape);
}
// Now apply expand to index_
{
VmapDimVector new_index_shape = {self_size.begin(), self_size.end()};
new_index_shape[dim] = idx_size;
index_ = index_.expand(new_index_shape);
}
return index_;
}
} // namespace
Tensor index_select_decomp(const Tensor &self, int64_t dim, const Tensor &index)
{
Tensor index_ = index;
if (self.dim() > index.dim()) {
index_ = get_expanded_index(index, self.sizes(), dim);
}
auto result = at::gather(self, dim, index_);
  // the output of gather has the same rank as `index` while the
  // output of index_select has the same rank as `self`.
// Eg. t = torch.tensor(1)
// idx = torch.tensor([0])
// torch.index_select(t, 0, idx) # 0-D
// torch.gather(t, 0, idx) # 1-D
if (self.dim() == 0 && result.dim() != 0) {
result = result.squeeze(-1);
}
return result;
}
Tensor index_copy_decomp(
const Tensor &self, int64_t dim,
const Tensor &index, const Tensor &source)
{
Tensor index_ = index;
if (self.dim() > index.dim()) {
index_ = get_expanded_index(index, self.sizes(), dim);
}
  return at::scatter(self, dim, index_, source);
}
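// Decomposes slice_scatter into a scatter. E.g. (a sketch): with self [4, 5],
// src [2, 5], dim = 0, start = 1, end = 3, step = 1, idx = arange(1, 3) is
// expanded to shape [2, 5], so the scatter writes the rows of src into rows
// 1..2 of self.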
Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src,
int64_t dim, c10::optional<int64_t> start,
c10::optional<int64_t> end, int64_t step)
{
auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong));
idx = get_expanded_index(idx, self.sizes(), dim);
return at::scatter(self, dim, idx, src);
}
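// Decomposes select_scatter into a scatter. E.g. (a sketch): with self [4, 5],
// source [5], dim = 0, index = 2, index_ expands to a [4, 5] tensor filled
// with 2 and source.unsqueeze(0) broadcasts to [4, 5], so the scatter
// overwrites row 2 of self with source.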
Tensor select_scatter_decomp(
const Tensor &self, const Tensor &source,
int64_t dim, int64_t index)
{
// supports negative index
index = maybe_wrap_dim(index, self.size(dim));
auto index_ = at::scalar_tensor(index, self.options().dtype(kLong));
return at::scatter(self, dim, index_.expand_as(self), source.unsqueeze(dim).expand_as(self));
}
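// Batch rule for diagonal_scatter: both inputs are given a batch dim at dim 0,
// so dim1/dim2 are wrapped against the logical rank and then shifted by one.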
std::tuple<Tensor, optional<int64_t>> diagonal_scatter_batch_rule(
const Tensor &self, c10::optional<int64_t> self_bdim,
const Tensor &src, c10::optional<int64_t> src_bdim,
int64_t offset, int64_t dim1, int64_t dim2)
{
auto self_ = moveBatchDimToFront(self, self_bdim);
auto src_ = moveBatchDimToFront(src, src_bdim);
auto batch_size = get_bdim_size2(self, self_bdim, src, src_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
dim1 = maybe_wrap_dim(dim1, self_logical_rank) + 1;
dim2 = maybe_wrap_dim(dim2, self_logical_rank) + 1;
return std::make_tuple(at::diagonal_scatter(self_, src_, offset, dim1, dim2), 0);
}
std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Tensor& other, optional<int64_t> other_bdim,
const Scalar& alpha) {
if (!index_bdim) {
// Handle scalar tensors... self, other can be scalar tensors
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
const auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
if (self_logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}
auto other_ = moveBatchDimToFront(other, other_bdim);
if (other_logical_rank == 0) {
other_ = other_.unsqueeze(-1);
}
dim = maybe_wrap_dim(dim, self_logical_rank);
const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
other_ = ensure_has_bdim(other_, other_bdim.has_value(), batch_size);
auto result = self_.index_add(dim + 1, index, other_, alpha);
if (self_logical_rank == 0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, 0);
}
  // Index is batched. A for-loop plus stack is the best thing I can come up
  // with right now. We really want a generalized index_add kernel in PyTorch.
auto batch_size = get_bdim_size3(self, self_bdim, other, other_bdim, index, index_bdim);
std::vector<Tensor> results;
results.reserve(batch_size);
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_bdim.has_value() ?
self.select(*self_bdim, i) : self;
const auto& other_slice = other_bdim.has_value() ?
other.select(*other_bdim, i) : other;
const auto& index_slice = index_bdim.has_value() ?
index.select(*index_bdim, i) : index;
results.push_back(at::index_add(self_slice, dim, index_slice, other_slice, alpha));
}
return std::make_tuple(at::stack(results), 0);
}
static std::tuple<Tensor,Tensor> binary_pointwise_align(
const Tensor & self,
optional<int64_t> self_bdim,
const Tensor & mask,
optional<int64_t> mask_bdim) {
// compute max logical rank
auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto other_logical_rank = rankWithoutBatchDim(mask, mask_bdim);
auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
auto tensor_ = moveBatchDimToFront(self, self_bdim);
auto other_ = moveBatchDimToFront(mask, mask_bdim);
// If the dimensions aren't aligned, we need to line them up.
// Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
// Note that only tensors that have a batch dim need to be modified.
// Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
tensor_ = maybePadToLogicalRank(tensor_, self_bdim, max_logical_rank);
other_ = maybePadToLogicalRank(other_, mask_bdim, max_logical_rank);
return std::make_tuple(tensor_, other_);
}
std::tuple<Tensor,optional<int64_t>> masked_fill_scalar_batch_rule(
const Tensor & self,
optional<int64_t> self_bdim,
const Tensor & mask,
optional<int64_t> mask_bdim,
const Scalar& source) {
auto tensors = binary_pointwise_align(self, self_bdim, mask, mask_bdim);
auto result = at::masked_fill(std::get<0>(tensors), std::get<1>(tensors), source);
return std::make_tuple(result, 0);
}
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("index.Tensor", index_plumbing);
m.impl("index_put_", index_put__plumbing);
m.impl("index_put", index_put_plumbing);
m.impl("_index_put_impl_", _index_put_impl__plumbing);
m.impl("slice_scatter", slice_scatter_decomp);
m.impl("select_scatter", select_scatter_decomp);
m.impl("index_copy", index_copy_decomp);
m.impl("index_select", index_select_decomp);
VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule);
VMAP_SUPPORT(index_add, index_add_batch_rule);
VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
VMAP_SUPPORT(gather, gather_batch_rule);
VMAP_SUPPORT(gather_backward, gather_backward_batch_rule);
VMAP_SUPPORT2(scatter, value, scatter_value_batch_rule);
VMAP_SUPPORT2(scatter, src, scatter_src_batch_rule);
VMAP_SUPPORT(scatter_add, scatter_add_batch_rule);
VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule);
VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule);
}
}} // namespace at::functorch