// blob: e360586b729baa62b876d4ed72b5fdd24e12e4e1 [file] [log] [blame]
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/ATen.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {
// Definition of the dispatch stub that at::sparse::flatten_indices() calls,
// keyed on the indices' device type, when there is more than one sparse
// dimension and the indices are non-empty.
DEFINE_DISPATCH(flatten_indices_stub);
} // namespace at::native
namespace at::sparse {
// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
// input = [[2, 4, 0],
// [3, 1, 10]]
// full_size = [2, 12]
// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each column `indices[:, i]` is a valid index
// into a tensor `t` of shape `full_size`, this returns the corresponding
// indices into the flattened tensor
// `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result is forced to be a clone of `indices`.
Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone /*= false*/) {
  // The number of sparse dimensions is the number of rows of `indices`.
  const int64_t sparse_dim = indices.size(0);

  // One sparse dimension: the single row already is the flattened index, so
  // dropping the leading dimension is enough. A contiguous copy is made only
  // when the caller asks for one.
  if (sparse_dim == 1) {
    Tensor squeezed = indices.squeeze(0);
    if (force_clone) {
      return squeezed.clone(at::MemoryFormat::Contiguous);
    }
    return squeezed;
  }

  // No index entries at all: skip the stub and return kLong zeros of length
  // indices.size(1).
  if (indices.numel() == 0) {
    return at::zeros({indices.size(1)}, indices.options().dtype(kLong));
  }

  // General case: forward to the dispatch stub selected by the indices'
  // device type, passing only the sizes of the sparse dimensions.
  const auto sparse_sizes = full_size.slice(0, sparse_dim);
  return at::native::flatten_indices_stub(indices.device().type(), indices, sparse_sizes);
}
// Flatten a sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
// except that this one allows partial flattening: only the specified dims are flattened. Note that
// the flattened indices might be uncoalesced if dims_to_flatten.size() < sparse_dim.
// Also, if the input indices are already coalesced, the flattened indices will also be sorted.
//
// args:
// indices: sparse tensor indices
// sizes: sparse tensor sizes
// dims_to_flatten: list of dimension indices to flatten
//
// Ex1:
// indices = [[2, 4, 0],
// [3, 1, 3]]
// sizes = [2, 12]
// dims_to_flatten = [0, 1]
// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
// dims_to_flatten = [1]
// new_indices = [ 3, 1, 3 ] # uncoalesced
Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten) {
  // Accumulator with one entry per column of `indices`, starting at zero and
  // using the same options as `indices`.
  Tensor flat = at::zeros({indices.size(1)}, indices.options());

  // Horner-style accumulation over the requested dims, in the order given:
  //   flat = flat * sizes[dim] + indices[dim, :]
  for (const auto pos : c10::irange(dims_to_flatten.size())) {
    const auto dim = dims_to_flatten[pos];
    flat.mul_(sizes[dim]).add_(indices.select(0, dim));
  }
  return flat;
}
Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
  /*
    Build the CSR row-pointer array from the row indices of a COO tensor.
    Inputs:
      `indices` points to the `nnz` COO row indices
      `dim` is the number of rows
      `nnz` is the number of non-zeros
    Output:
      `csr` is a kLong tensor of length `dim + 1`; entry `r + 1` is set to the
      number of entries seen up to and including the last entry of row `r`.
    NOTE(review): presumably `indices` is sorted in non-decreasing order (the
    fill loop only advances when the next row index is larger) -- confirm with
    callers.
  */
  Tensor csr = at::zeros({dim + 1}, kLong);

  // TODO: eliminate this conditional when zero-size dims supported correctly
  if (nnz <= 0) {
    return csr;
  }

  auto row_ptr = csr.accessor<int64_t, 1>();
  // Each chunk [begin, end) of entries is processed independently.
  at::parallel_for(0, nnz, 10000, [&](int64_t begin, int64_t end) {
    for (const auto k : c10::irange(begin, end)) {
      const int64_t row = indices[k];
      // The last entry closes every remaining row up to `dim`.
      const int64_t next_row = (k + 1 == nnz) ? dim : indices[k + 1];
      // Rows in [row, next_row) all end at entry k + 1. The loop body does not
      // run when the next entry lies in the same row.
      for (int64_t r = row; r < next_row; ++r) {
        row_ptr[r + 1] = k + 1;
      }
    }
  });
  return csr;
}
Tensor zeros_like_with_indices(const Tensor& t) {
  // Only sparse tensors are accepted.
  TORCH_INTERNAL_ASSERT(t.is_sparse());

  // Independent copy of the indices, so the result does not share them with `t`.
  const Tensor indices_copy = t._indices().clone();

  // A one-element zero tensor expanded to the shape of t's values.
  const Tensor values = t._values();
  const Tensor zero_values = at::zeros({1}, values.options()).expand_as(values);

  // Same dims, sizes, options and coalesced flag as `t`.
  return at::_sparse_coo_tensor_with_dims_and_tensors(
      t.sparse_dim(),
      t.dense_dim(),
      t.sizes(),
      indices_copy,
      zero_values,
      t.options(),
      t.is_coalesced());
}
// Builds the COO indices of a tensor of shape `sizes` in which every element
// is specified.
//
// args:
//   sizes: shape to enumerate; must be non-empty
//   options: tensor options used for the arange the indices are built from
//
// returns: a tensor of shape (sizes.size(), prod(sizes)); row `i` holds the
// coordinate along dimension `i` of each enumerated position.
Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options) {
  // std::max_element returns the end iterator for an empty range, and
  // dereferencing it is undefined behaviour, so reject empty `sizes` up front.
  TORCH_CHECK(!sizes.empty(), "full_coo_indices: sizes must be non-empty");

  // One arange of the largest extent is shared; each dimension narrows it.
  const auto max_size = *std::max_element(sizes.begin(), sizes.end());
  const auto max_size_arange = at::arange(max_size, options);

  const auto ndim = static_cast<int64_t>(sizes.size());
  std::vector<Tensor> stack;
  stack.reserve(sizes.size());
  for (const auto i : c10::irange(ndim)) {
    // Coordinates 0 .. sizes[i] - 1 along dimension i.
    Tensor a = max_size_arange.narrow(-1, 0, sizes[i]);
    // Insert singleton dims everywhere except position i so that `a` can be
    // expanded to the full `sizes` shape.
    for (const auto j : c10::irange(ndim)) {
      if (i != j) {
        a.unsqueeze_(j);
      }
    }
    stack.push_back(a.expand(sizes));
  }
  // Stack the per-dimension coordinate grids and flatten all trailing dims.
  return at::stack(stack).flatten(1, -1);
}
} // namespace at::sparse