#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/NamedTensorUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/align_as_native.h>
#include <ATen/ops/align_tensors_native.h>
#include <ATen/ops/align_to_native.h>
#include <ATen/ops/gather_native.h>
#include <ATen/ops/index_add_native.h>
#include <ATen/ops/index_copy_native.h>
#include <ATen/ops/index_fill.h>
#include <ATen/ops/index_fill_native.h>
#include <ATen/ops/index_select_native.h>
#include <ATen/ops/refine_names_native.h>
#include <ATen/ops/rename_native.h>
#include <ATen/ops/scatter_add_native.h>
#include <ATen/ops/scatter_native.h>
#include <ATen/ops/sort_native.h>
#include <ATen/ops/squeeze.h>
#include <ATen/ops/squeeze_native.h>
#include <ATen/ops/zeros_like_ops.h>
#endif
#include <c10/util/irange.h>
#include <bitset>
namespace at::native {
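// rename_ sets `self`'s dimension names in place; rename returns an alias of
// `self` with the new names. Passing nullopt clears all names. For example
// (illustrative, Python API): `t.rename('N', 'C')` names the dims of a 2D
// tensor, and `t.rename(None)` drops its names.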
Tensor& rename_(Tensor& self, optional<DimnameList> names) {
at::internal_set_names_inplace(self, names);
return self;
}
Tensor rename(const Tensor& self, optional<DimnameList> names) {
auto result = self.alias();
at::internal_set_names_inplace(result, names);
return result;
}
static void report_moving_unnamed_dim_error(
DimnameList names, DimnameList other, bool is_aligning_two_tensors) {
if (is_aligning_two_tensors) {
TORCH_CHECK(false,
"Aligning Tensor", names, " and Tensor", other,
" would change the absolute position from the right of an unnamed dimension. ",
"Please name unnamed dimensions to avoid ambiguity.");
} else {
TORCH_CHECK(false,
"Aligning Tensor", names, " to `names` ", other,
" would change the absolute position from the right of an unnamed dimension. ",
"Please name unnamed dimensions to avoid ambiguity.");
}
}
static void report_not_a_subsequence_error(
DimnameList names, DimnameList other, bool is_aligning_two_tensors) {
if (is_aligning_two_tensors) {
auto shorter = names.size() > other.size() ? other : names;
auto longer = names.size() > other.size() ? names : other;
    TORCH_CHECK(false,
        "Could not align Tensor", shorter, " and Tensor", longer,
        " because ", shorter, " is not a subsequence of ", longer, ".");
} else {
TORCH_CHECK(false,
"Could not align Tensor", names, " to `names` ", other,
" because ", names, " is not a subsequence of `names`.");
}
}
// Let tensor `t` have sizes `tensor_sizes` and names `tensor_names`.
// This helper function computes the resulting sizes of `t` after aligning it
// to `aligned_names`. It enforces the alignment rules in Note [Alignment rules].
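//
// For example (illustrative): aligning sizes [3, 5] with names ['C', 'H'] to
// aligned_names ['N', 'C', 'H', 'W'] yields [1, 3, 5, 1]; matched names keep
// their sizes and unmatched names become size-1 dims.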
static std::vector<int64_t> aligned_size(
IntArrayRef tensor_sizes,
DimnameList tensor_names,
DimnameList aligned_names,
bool is_aligning_two_tensors) {
std::vector<int64_t> expanded_sizes(aligned_names.size(), 1);
ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1;
ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1;
for (; idx >= 0 && dim >= 0; --idx) {
if (tensor_names[dim] != aligned_names[idx]) {
continue;
}
    // We've found a name that occurs in both `tensor_names` and `aligned_names`.
    // If it is a None (wildcard) name and its absolute positions from the right
    // in the two lists are not equal, then aligning the names would require
    // changing the absolute position from the right of the None name,
    // violating condition 2 of our [Alignment rules].
    //
    // For example:
    // *, c, a, b
    //       *, a
    // [*, a] is a subsequence of [*, c, a, b], but in order to align them,
    // we'd have to move the * to create [*, c: 1, a, b: 1]
if (tensor_names[dim].isWildcard() &&
tensor_sizes.size() - dim != aligned_names.size() - idx) {
      report_moving_unnamed_dim_error(
          tensor_names, aligned_names, is_aligning_two_tensors);
}
expanded_sizes[idx] = tensor_sizes[dim];
--dim;
}
if (dim != -1) {
    report_not_a_subsequence_error(
        tensor_names, aligned_names, is_aligning_two_tensors);
}
return expanded_sizes;
}
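// refine_names coerces the unnamed (None) dims of `self` to the given names;
// already-named dims must match the target exactly. For example (illustrative,
// Python API): `torch.rand(2, 3).refine_names('N', 'C')` returns Tensor[N: 2,
// C: 3], while refining Tensor[N, C] with ('N', 'H') errors because C != H.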
Tensor refine_names(const Tensor& self, DimnameList names) {
const auto self_names = self.names();
TORCH_CHECK(self_names.size() == names.size(),
"refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
" because they have a different number of dims (",
self_names.size(), " and ", names.size(), " respectively).");
check_names_valid_for(self, names);
for (const auto idx : c10::irange(self_names.size())) {
const auto& self_name = self_names[idx];
const auto& out_name = names[idx];
if (self_name == out_name || self_name.isWildcard()) {
continue;
}
if (out_name.isWildcard()) {
TORCH_CHECK(false,
"refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
" because ", self_name, " is more specific than ", out_name, " at index ",
idx);
}
TORCH_CHECK(false,
"refine_names: cannot coerce Tensor", self_names, " to Tensor", names,
" because ", self_name, " is different from ", out_name, " at index ",
idx);
TORCH_INTERNAL_ASSERT(false); // done handling errors
}
auto result = self.alias();
internal_set_names_inplace(result, names);
return result;
}
// [Alignment rules]
// Aligns `tensor` to names with the following rules:
// 1) Check that tensor.names is a subsequence (not necessarily contiguous) of `names`.
// 2) Aligning tensor.names to names must not change the absolute position from the
// right of any unnamed dimension.
//
// is_aligning_two_tensors tunes the error message to better match the following cases:
// 1) tensor.align_to(names) (is_aligning_two_tensors=false)
// 2) torch.align_tensors([tensor, other]) (is_aligning_two_tensors=true)
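//
// For example (illustrative): aligning Tensor[C, H] to ['N', 'C', 'H', 'W']
// succeeds because [C, H] is a subsequence of [N, C, H, W], producing a
// Tensor[N: 1, C, H, W: 1] view; aligning Tensor[W, C] to the same names
// violates rule 1 because [W, C] is not a subsequence of [N, C, H, W].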
static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_two_tensors) {
std::vector<int64_t> expanded_sizes = aligned_size(
tensor.sizes(),
tensor.names(),
names,
is_aligning_two_tensors);
auto result = tensor.rename(nullopt).view(expanded_sizes);
at::internal_set_names_inplace(result, names);
return result;
}
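// Returns how many of the bits in positions [0, up_to_idx) of `set` are unset.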
static int64_t countUnset(std::bitset<kMaxNamedTensorDim> set, int64_t up_to_idx) {
int64_t result = 0;
for (const auto i : c10::irange(up_to_idx)) {
if (!set.test(i)) result++;
}
return result;
}
// Handles `tensor.align_to(*order)` in the case where there is an ellipsis.
//
// Let tensor: Tensor[N, C, H, W]. Consider `tensor.align_to('W', ..., 'N')`.
// We expand the `...` to "all unmentioned dimensions, in the order in which
// they appear in the original tensor."
//
// `order` is passed in **without** the ellipsis name. This is because ellipsis
// is not a valid name in C++ right now. Future work should be done on making
// ellipsis a valid name.
//
// `ellipsis_idx` is where the ellipsis occurs in the Python call.
// In our example, `tensor.align_to('W', ..., 'N')`, order = ['W', 'N'] and
// ellipsis_idx = 1.
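// The `...` therefore expands to ['C', 'H'] (the unmentioned dims, in original
// order), and the result of our example is a Tensor[W, C, H, N].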
Tensor align_to(const Tensor& tensor, DimnameList order, int64_t ellipsis_idx) {
const auto tensor_names = tensor.names();
const auto tensor_sizes = tensor.sizes();
const auto tensor_strides = tensor.strides();
const auto tensor_dim = tensor.sizes().size();
constexpr int64_t not_found = -1;
// General strategy.
//
// Step 1: We compute the following 3 things:
// 1. How many names the ellipsis should expand to
// 2. Which names in `tensor.names` are not mentioned in `order`.
// 3. Where names in `order` occur in tensor, if at all.
//
// Step 2: Compute the new sizes/strides/names.
// First, determine the ndim of the output tensor (this is not obvious)
// by counting the number of names in `tensor` that are not in `order`.
// Next, fill in output sizes/strides/names by using `order` and knowledge
// of which dimensions in `tensor` are unmentioned in `order`.
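  //
  // For example (illustrative), for Tensor[N, C, H, W].align_to('W', ..., 'N'):
  // 'N' and 'W' are mentioned in `order` while 'C' and 'H' are not, so the
  // ellipsis expands to 2 names and the output has 2 + order.size() = 4 dims.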
std::bitset<kMaxNamedTensorDim> order_has_tensor_name;
  // tensor_idx_for[i] = j means that the ith name in `order`
  // is the name of the jth dimension of `tensor`.
std::vector<int64_t> tensor_idx_for(order.size(), not_found);
for (const auto order_idx : c10::irange(order.size())) {
const auto name = order[order_idx];
TORCH_CHECK(name.isBasic(),
"align_to: the desired order of dimensions cannot contain a None name, got ",
order);
auto it = std::find(tensor_names.begin(), tensor_names.end(), name);
if (it == tensor_names.end()) {
continue;
}
auto idx_in_tensor = std::distance(tensor_names.begin(), it);
tensor_idx_for[order_idx] = idx_in_tensor;
order_has_tensor_name.set(idx_in_tensor);
}
const auto num_ellipsis_names = countUnset(order_has_tensor_name, tensor_dim);
const auto out_dim = num_ellipsis_names + order.size();
// Step 2: Now that we know the size of the output tensor, we can use the
// metadata obtained from Step 1 to fill in the new sizes/strides/names
std::vector<int64_t> new_sizes(out_dim, 1);
std::vector<int64_t> new_strides(out_dim, 0);
std::vector<Dimname> new_names(out_dim, Dimname::wildcard());
auto setNewSizesStridesNamesFor = [&](int64_t out_dim, int64_t tensor_dim) {
new_sizes[out_dim] = tensor_sizes[tensor_dim];
new_strides[out_dim] = tensor_strides[tensor_dim];
new_names[out_dim] = tensor_names[tensor_dim];
};
// Fill in the non-ellipsis dimensions
for (const auto order_idx : c10::irange(static_cast<int64_t>(order.size()))) {
auto out_idx = order_idx;
if (order_idx >= ellipsis_idx) {
out_idx = order_idx + num_ellipsis_names;
}
const auto tensor_idx = tensor_idx_for[order_idx];
if (tensor_idx == not_found) {
// We are adding a new size-one dimension
new_names[out_idx] = order[order_idx];
continue;
}
setNewSizesStridesNamesFor(out_idx, tensor_idx);
}
// Fill in the ellipsis dimensions
for (const auto tensor_idx : c10::irange(tensor_dim)) {
if (order_has_tensor_name.test(tensor_idx)) {
continue;
}
setNewSizesStridesNamesFor(ellipsis_idx, tensor_idx);
ellipsis_idx++;
}
check_names_valid_for(out_dim, new_names);
Tensor result;
{
NoNamesGuard guard;
result = tensor.as_strided(new_sizes, new_strides);
}
internal_set_names_inplace(result, std::move(new_names), /*validate_names=*/false);
return result;
}
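// Handles `tensor.align_to(names)` when there is no ellipsis: every input dim
// must be named and must occur somewhere in `names`. Unmatched entries of
// `names` become size-1 (stride-0) dims, and input dims may be reordered to
// match `names`.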
Tensor align_to(const Tensor& tensor, DimnameList names) {
auto tensor_names = tensor.names();
auto tensor_sizes = tensor.sizes();
auto tensor_strides = tensor.strides();
std::vector<int64_t> new_sizes(names.size(), 1);
std::vector<int64_t> new_strides(names.size(), 0);
for (const auto idx : c10::irange(tensor_names.size())) {
const auto& dim = tensor_names[idx];
TORCH_CHECK(dim.isBasic(),
"align_to: All input dims must be named. Found unnamed dim at index ",
idx, " of Tensor", tensor_names);
auto it = std::find(names.begin(), names.end(), dim);
    TORCH_CHECK(it != names.end(),
        "align_to: Cannot find dim ", dim, " from Tensor", tensor_names,
        " in desired alignment ", names, ".");
int64_t new_idx = std::distance(names.begin(), it);
new_sizes[new_idx] = tensor_sizes[idx];
new_strides[new_idx] = tensor_strides[idx];
}
Tensor result;
{
NoNamesGuard guard;
result = tensor.as_strided(new_sizes, new_strides);
}
internal_set_names_inplace(result, names);
return result;
}
Tensor align_as(const Tensor& tensor, const Tensor& other) {
return native::align_to(tensor, other.names());
}
static std::vector<Tensor> align_tensors_to(TensorList tensors, DimnameList names) {
std::vector<Tensor> result;
result.reserve(tensors.size());
for (const auto& tensor : tensors) {
result.emplace_back(align(tensor, names, /*is_aligning_two_tensors=*/true));
}
return result;
}
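// Aligns all tensors to the names of the tensor with the most dims.
// For example (illustrative): align_tensors([Tensor[N, C], Tensor[C]]) returns
// two views named [N, C], with a size-1 N dim filled in for the second tensor.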
std::vector<Tensor> align_tensors(TensorList tensors) {
auto longest_dim = std::max_element(
tensors.begin(), tensors.end(),
[](const Tensor& a, const Tensor& b) {
return a.dim() < b.dim();
});
return align_tensors_to(tensors, longest_dim->names());
}
// Misc. Dimname overloads that don't have homes. Maybe we should move
// all of them here or autogenerate them because they look so similar.
Tensor gather(const Tensor& self, Dimname dim, const Tensor& index, bool sparse_grad) {
reportNYIDimnameOverload("gather");
}
Tensor& gather_out(const Tensor& self, Dimname dim, const Tensor& index, bool sparse_grad, Tensor& result) {
reportNYIDimnameOverload("gather");
}
Tensor index_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) {
reportNYIDimnameOverload("index_add");
}
static Tensor& index_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) {
reportNYIDimnameOverload("index_add");
}
static Tensor& index_add_out(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar& alpha, Tensor& result) {
reportNYIDimnameOverload("index_add");
}
Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
return at::index_fill(self, dimname_to_position(self, dim), index, source);
}
Tensor& index_fill_(Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
return self.index_fill_(dimname_to_position(self, dim), index, source);
}
Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
return at::index_fill(self, dimname_to_position(self, dim), index, source);
}
Tensor& index_fill_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
return self.index_fill_(dimname_to_position(self, dim), index, source);
}
Tensor index_copy(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
reportNYIDimnameOverload("index_copy");
}
Tensor& index_copy_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
reportNYIDimnameOverload("index_copy");
}
Tensor& index_select_out(const Tensor& self, Dimname dim, const Tensor& index, Tensor& out) {
reportNYIDimnameOverload("index_select");
}
Tensor index_select(const Tensor& self, Dimname dim, const Tensor& index) {
reportNYIDimnameOverload("index_select");
}
Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
reportNYIDimnameOverload("scatter");
}
static Tensor& scatter_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
reportNYIDimnameOverload("scatter");
}
Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
reportNYIDimnameOverload("scatter");
}
static Tensor& scatter_(Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) {
reportNYIDimnameOverload("scatter");
}
Tensor scatter_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
reportNYIDimnameOverload("scatter_add");
}
static Tensor& scatter_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
reportNYIDimnameOverload("scatter_add");
}
std::tuple<Tensor&, Tensor&> sort_out(const Tensor& self, c10::optional<bool> stable, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
reportNYIDimnameOverload("sort");
}
std::tuple<Tensor&, Tensor&> sort_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
reportNYIDimnameOverload("sort");
}
std::tuple<Tensor, Tensor> sort(const Tensor& self, c10::optional<bool> stable, Dimname dim, bool keepdim) {
reportNYIDimnameOverload("sort");
}
std::tuple<Tensor, Tensor> sort(const Tensor& self, Dimname dim, bool keepdim) {
reportNYIDimnameOverload("sort");
}
Tensor& squeeze_(Tensor& self, Dimname dim) {
reportNYIDimnameOverload("squeeze");
}
Tensor squeeze(const Tensor& self, Dimname dim) {
return at::squeeze(self, dimname_to_position(self, dim));
}
} // namespace at::native