// blob: a8adfeb7ac10fe1f41ea667ac5a2e7fb951f8e36 [file] [log] [blame]
#include <TH/THTensor.hpp>
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <ATen/native/Resize.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/quantized/QTensorImpl.h>
#include <algorithm>
#include <cstdlib>
#include <numeric>
#include <vector>
namespace at {
namespace native {
// Reshapes `self` to the shape whose extents are stored, in order, in
// `shape_tensor`.
//
// `shape_tensor` must be 1-dimensional; its elements are read through an
// int64_t accessor and forwarded to Tensor::reshape.
Tensor _reshape_from_tensor(const Tensor& self, const Tensor& shape_tensor) {
  TORCH_CHECK(shape_tensor.dim() == 1,
           "_reshape_from_tensor expects a 1-dimensional shape tensor, but got ",
           shape_tensor.dim(), " dimensions");
  auto accessor = shape_tensor.accessor<int64_t, 1>();
  // numel() is signed; keep the loop index signed as well so the comparison
  // does not mix signed and unsigned operands.
  const int64_t num_dims = shape_tensor.numel();
  std::vector<int64_t> shape;
  shape.reserve(num_dims);
  for (int64_t i = 0; i < num_dims; ++i) {
    shape.push_back(accessor[i]);
  }
  return self.reshape(IntArrayRef(shape));
}
// Returns the sizes of `self` as a kLong tensor. The result's is_variable
// option mirrors that of `self`.
Tensor _shape_as_tensor(const Tensor& self) {
  const bool as_variable = self.options().is_variable();
  const auto shape_options = TensorOptions(at::kLong).is_variable(as_variable);
  return at::tensor(self.sizes(), shape_options);
}
// Thin wrapper: hands the tensor list to expand_outplace and returns its
// result unchanged.
std::vector<Tensor> broadcast_tensors(TensorList tensors) {
  std::vector<Tensor> expanded = expand_outplace(tensors);
  return expanded;
}
// Raises (via TORCH_CHECK) if any tensor in the list is zero-dimensional,
// reporting the position of the first offender.
static void check_cat_no_zero_dim(TensorList tensors) {
  size_t position = 0;
  for (const Tensor& tensor : tensors) {
    TORCH_CHECK(tensor.dim() > 0,
             "zero-dimensional tensor (at position ", position, ") cannot be concatenated");
    ++position;
  }
}
// Out-variant of cat: rejects zero-dimensional inputs, wraps `dim` with the
// legacy cat rule, then dispatches to at::_cat_out.
Tensor & cat_out(Tensor & result, TensorList tensors, int64_t dim) {
  check_cat_no_zero_dim(tensors);
  const int64_t wrapped_dim = legacy_cat_wrap_dim(dim, tensors);
  return at::_cat_out(result, tensors, wrapped_dim);
}
// Returns true when `s1` and `s2` have the same length and agree in every
// position other than `dim_except`.
static bool sizes_match_except(IntArrayRef s1, IntArrayRef s2, int64_t dim_except /* should already be wrapped */) {
  const size_t rank = s1.size();
  if (rank != s2.size()) {
    return false;
  }
  for (size_t d = 0; d < rank; ++d) {
    const bool is_excepted = static_cast<int64_t>(d) == dim_except;
    if (!is_excepted && s1[d] != s2[d]) {
      return false;
    }
  }
  return true;
}
// Validates that tensor `t` can take part in a sparse concatenation whose
// reference shape is `sizes`:
//   * `t` must be sparse,
//   * its shape must equal `sizes` everywhere except dimension `wrapped`,
//   * its sparse_dim / dense_dim must equal the given reference values.
// `pos` is only used to build the error messages.
static void check_cat_sparse_dims(Tensor const &t,
  int64_t pos /* used only for debug messages */,
  IntArrayRef sizes,
  int64_t wrapped,
  int64_t sparse_dim,
  int64_t dense_dim) {
  TORCH_CHECK(t.is_sparse(),
           "Concatenating sparse tensors, but a dense tensor was found at position ", pos, ".");

  const bool shapes_compatible = sizes_match_except(sizes, t.sizes(), wrapped);
  TORCH_CHECK(shapes_compatible,
           "All tensors must have the same shape: ", sizes, " (except in the concatenating dimension),"
           " but found shape: ", t.sizes(), " at position ", pos, ".");

  const bool same_dim_split = (t.sparse_dim() == sparse_dim) && (t.dense_dim() == dense_dim);
  TORCH_CHECK(same_dim_split,
           "All tensors must have the same sparse_dim and dense_dim: ", sparse_dim, ", ", dense_dim,
           ", but tensor at position ", pos, " has ", t.sparse_dim(), ", ", t.dense_dim(), ".");
}
// Concatenates sparse tensors along `dim`.
//
// The first tensor provides the reference shape, sparse_dim and dense_dim;
// every tensor is validated against it with check_cat_sparse_dims.
// `tensors[0]` is read unconditionally, so the list must be non-empty.
//
// Two cases are handled:
//   * `dim` is a sparse dimension: indices and values are concatenated and the
//     index row for `dim` is shifted by the cumulative size of the preceding
//     tensors.
//   * `dim` is a dense dimension: each values tensor is zero-padded on both
//     sides along the corresponding values dimension and the padded pieces are
//     concatenated. The result may be uncoalesced.
static Tensor cat_sparse(TensorList tensors, int64_t dim) {
  std::vector<Tensor> indices;
  std::vector<Tensor> values;
  int64_t wrapped = maybe_wrap_dim(dim, tensors[0].dim());
  int64_t sparse_dim = tensors[0].sparse_dim();
  int64_t dense_dim = tensors[0].dense_dim();
  IntArrayRef sizes = tensors[0].sizes();
  if (wrapped < sparse_dim) {
    indices.reserve(tensors.size());
    values.reserve(tensors.size());
    for (size_t i = 0; i < tensors.size(); ++i) {
      auto const &t = tensors[i];
      check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
      indices.push_back(t._indices());
      values.push_back(t._values());
    }
    Tensor idxs = at::cat(indices, 1);
    Tensor vals = at::cat(values, 0);

    // We now need to move the indices of each
    // input tensor up along `dim` by an appropriate amount.
    // E.g., if t1 has indices [[2,3,4],[5,6,7]],
    // and sizes [10, 7]
    // then torch.cat((t1,t1,t1),1) should have indices
    // [[2,3,4,2,3,4,2,3,4],[5,6,7,12,13,14,19,20,21]],
    // so we need to increase idxs[1][3:6] by 7
    // and idxs[1][6:9] by 14.
    int64_t col = 0;
    int64_t cumulative_offset = 0;
    for (size_t i = 0; i < tensors.size(); ++i) {
      auto const &t = tensors[i];
      int64_t this_piece_size = t._nnz();
      // cumulative_offset is zero for the first piece, so
      // don't waste time doing this operation unless i > 0.
      if (i > 0) {
        idxs[wrapped].narrow(0, col, this_piece_size) += cumulative_offset;
      }
      cumulative_offset += t.size(wrapped);
      col += this_piece_size;
    }
    auto sizes_copy = sizes.vec();
    sizes_copy[wrapped] = cumulative_offset;
    return native::sparse_coo_tensor(idxs, vals, sizes_copy, tensors[0].options());
  }
  else {
    // Catting along a dense dimension requires us to create new values.
    // For illustration, consider the sparse 3d tensors t1 and t2,
    // given by t1 = [[[1,2],[3,4]], ... (zeros) ..., [[5,6],[7,8]]]
    // and t2 = [... (zeros) ..., [[9, 10], [11,12]], ... (zeros) ...],
    // Their concatenation along dimension 2 is:
    // [[[1,2,0,0],[3,4,0,0]], ... (zeros) ..., [[0,0,9,10],[0,0,11,12]], ... (zeros) ..., [[5,6,0,0],[7,8,0,0]]]
    //
    // Their values tensors are, respectively,
    // [[[1,2],[3,4]],[[5,6],[7,8]]] and [[[9,10],[11,12]]].
    //
    // and so the values tensor of their concatenation along dim 2 will be:
    // [[[1,2,0,0],[3,4,0,0]],[[5,6,0,0],[7,8,0,0]],[[0,0,9,10],[0,0,11,12]]]
    //
    // which we can get by taking the values tensor of each tensor, catting it with zeros of the appropriate size on the left and right,
    // and then catting all those results together.

    // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting.
    int64_t values_dim = wrapped - sparse_dim + 1;
    // The final size along the catted dimension.
    // The initial value must be an int64_t: std::accumulate keeps its running
    // total in the type of the initial value, so a plain `0` would accumulate
    // in `int` and truncate large sizes.
    int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast<int64_t>(0), [values_dim](int64_t l, Tensor const &r) {
      return l + r._values().size(values_dim);
    });
    auto zeros_sizes = tensors[0]._values().sizes().vec();
    int64_t cumulative_size = 0;
    std::vector<Tensor> vals_pieces;
    std::vector<Tensor> idxs_pieces;
    vals_pieces.reserve(tensors.size());
    idxs_pieces.reserve(tensors.size());
    for (size_t i = 0; i < tensors.size(); ++i) {
      auto const &t = tensors[i];
      check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
      // dimension 0 of values corresponds to the number of values,
      // rather than to any logical dimension of the sparse tensor.
      zeros_sizes[0] = t._values().size(0);
      // Left padding: everything contributed by the tensors before this one.
      zeros_sizes[values_dim] = cumulative_size;
      cumulative_size += t._values().size(values_dim);
      auto z1 = native::zeros(zeros_sizes, t._values().options());
      // Right padding: whatever remains up to the final size.
      zeros_sizes[values_dim] = total_size - cumulative_size;
      auto z2 = native::zeros(zeros_sizes, t._values().options());
      vals_pieces.push_back(native::cat({z1, t._values(), z2}, values_dim));
      idxs_pieces.push_back(t._indices());
    }
    auto sizes_copy = sizes.vec();
    sizes_copy[wrapped] = total_size;
    // This can create an uncoalesced tensor
    return native::sparse_coo_tensor(native::cat(idxs_pieces, 1), native::cat(vals_pieces), sizes_copy, tensors[0].options());
  }
}
// Concatenates `tensors` along `dim`. A non-empty list whose first tensor is
// sparse is routed to cat_sparse; everything else goes through at::_cat after
// the zero-dim check and legacy dimension wrapping.
Tensor cat(TensorList tensors, int64_t dim) {
  const bool sparse_inputs = tensors.size() > 0 && tensors[0].is_sparse();
  if (sparse_inputs) {
    return cat_sparse(tensors, dim);
  }

  check_cat_no_zero_dim(tensors);
  const int64_t wrapped_dim = legacy_cat_wrap_dim(dim, tensors);
  return at::_cat(tensors, wrapped_dim);
}
// Splits `self` along `dim` into pieces of size ceil(size(dim) / chunks).
// Requires at least one dimension and chunks > 0.
std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
  TORCH_CHECK(self.dim() > 0,
           "chunk expects at least a 1-dimensional tensor");
  TORCH_CHECK(chunks > 0,
           "chunk expects `chunks` to be greater than 0, got: ", chunks);

  const int64_t dim_size = self.size(dim);
  const int64_t split_size = (dim_size + chunks - 1) / chunks;

  // Common case: defer to split.
  if (split_size != 0 || dim_size != 0) {
    return self.split(split_size, dim);
  }

  // Both the split size and the dimension size are 0. A plain split would
  // discard the requested number of chunks (any number of 0-sized chunks adds
  // up to 0), so spell the sizes out and use split_with_sizes instead.
  // Eventually this should be done for all cases.
  std::vector<int64_t> split_sizes(chunks, split_size);
  split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
  return self.split_with_sizes(split_sizes, dim);
}
// Flattens a contiguous copy/alias of `self` to 1-D and calls diag(offset)
// on it.
Tensor diagflat(const Tensor& self, int64_t offset) {
  const Tensor flattened = self.contiguous().view(-1);
  return flattened.diag(offset);
}
// Returns a strided view of the diagonal of `self` taken over dimensions
// `dim1_` and `dim2_` (both wrapped), shifted by `offset`. The two dimensions
// are removed and the diagonal becomes the last dimension of the result.
Tensor diagonal(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
  const int64_t ndim = self.dim();
  const int64_t dim1 = maybe_wrap_dim(dim1_, ndim);
  const int64_t dim2 = maybe_wrap_dim(dim2_, ndim);
  TORCH_CHECK(dim1 != dim2, "diagonal dimensions cannot be identical ", dim1_, ", ", dim2_);

  const bool on_or_above_main = offset >= 0;
  const int64_t size1 = self.size(dim1);
  const int64_t size2 = self.size(dim2);
  const int64_t stride1 = self.stride(dim1);
  const int64_t stride2 = self.stride(dim2);

  // Length of the diagonal.
  // For positive offsets (above the main diagonal) the "leftmost columns"
  // (along dim2) are dropped; for negative offsets (below the main diagonal)
  // the "topmost rows" (along dim1) are dropped. `offset` is negative in the
  // second case, hence the `+`.
  const int64_t unclamped_len = on_or_above_main
      ? std::min(size1, size2 - offset)
      : std::min(size1 + offset, size2);
  const int64_t diag_size = std::max<int64_t>(unclamped_len, 0);

  // NumPy allows offsets "off the end"; in that case the diagonal is empty and
  // the storage offset is left untouched rather than set to something
  // ridiculous (it would not matter technically, but let's be kosher).
  int64_t storage_offset = self.storage_offset();
  if (diag_size != 0) {
    if (on_or_above_main) {
      storage_offset += offset * stride2;
    } else {
      // Subtracting absorbs the negative sign of `offset`.
      storage_offset -= offset * stride1;
    }
  }

  // Build the new geometry: drop dim1 and dim2 (the larger index first, so the
  // smaller index stays valid), then append the joint diagonal dimension at
  // the end to match numpy semantics.
  const int64_t larger_dim = std::max(dim1, dim2);
  const int64_t smaller_dim = std::min(dim1, dim2);
  auto sizes = self.sizes().vec();
  auto strides = self.strides().vec();
  sizes.erase(sizes.begin() + larger_dim);
  strides.erase(strides.begin() + larger_dim);
  sizes.erase(sizes.begin() + smaller_dim);
  strides.erase(strides.begin() + smaller_dim);
  sizes.push_back(diag_size);
  strides.push_back(stride1 + stride2);

  // Return a view with the new parameters.
  return self.as_strided(sizes, strides, storage_offset);
}
// Builds a zero tensor with one more dimension than `self` and copies `self`
// onto its diagonal (taken over `dim1_`/`dim2_` with the given `offset`).
// Both new dimensions have length |offset| + self.size(-1).
Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
  const int64_t out_ndim = self.dim() + 1;
  const int64_t dim1 = maybe_wrap_dim(dim1_, out_ndim);
  const int64_t dim2 = maybe_wrap_dim(dim2_, out_ndim);
  TORCH_CHECK(dim1 != dim2, "diagonal dimensions cannot be identical ", dim1_, ", ", dim2_);

  const int64_t side_len = std::abs(offset) + self.size(-1);

  // Replace the last dimension of `self` by the two square dimensions,
  // inserting the lower index first so the higher index lands correctly.
  auto out_sizes = self.sizes().vec();
  out_sizes.pop_back();
  const int64_t lower_dim = std::min(dim1, dim2);
  const int64_t upper_dim = std::max(dim1, dim2);
  out_sizes.insert(out_sizes.begin() + lower_dim, side_len);
  out_sizes.insert(out_sizes.begin() + upper_dim, side_len);

  auto result = at::zeros(out_sizes, self.options());
  result.diagonal(offset, dim1, dim2).copy_(self);
  return result;
}
// Returns a view of `self` expanded to `size`, with geometry computed by
// inferExpandGeometry. `size` must have at least as many entries as `self`
// has dimensions.
Tensor expand(const Tensor& self, IntArrayRef size, bool implicit) {
  // [expand implicit]
  // `implicit` is true for expand calls inserted by the broadcast operators in
  // ExpandUtils.h. The tracer records the flag so it can tell those apart from
  // expands the user asked for explicitly: implicit expands may legally be
  // removed from the graph, explicit ones may not.
  const size_t self_ndim = static_cast<size_t>(self.dim());
  TORCH_CHECK(size.size() >= self_ndim,
           "expand(", self.type(), "{", self.sizes(), "}, size=", size,
           "): the number of sizes provided (", size.size(), ") ",
           "must be greater or equal to the number of dimensions in the tensor (",
           self.dim(), ")");

  auto geometry = inferExpandGeometry(self.sizes(), self.strides(), size);
  return self.as_strided(std::get<0>(geometry), std::get<1>(geometry));
}
// Expands `self` to the sizes of `other`.
Tensor expand_as(const Tensor& self, const Tensor& other) {
  const IntArrayRef target_sizes = other.sizes();
  return self.expand(target_sizes);
}
// Reduces `self` to `size` via sum_to, after checking with is_expandable_to
// that `size` can be expanded to the shape of `self`.
Tensor sum_to_size(const Tensor& self, IntArrayRef size) {
  const bool reducible = is_expandable_to(size, self.sizes());
  TORCH_CHECK(reducible,
           "size {", size, "} is not expandable to size {", self.sizes(), "}.");
  return sum_to(self, size);
}
// Creates a new TensorImpl-backed tensor over the storage of `self` with the
// given size/stride. When `storage_offset_` is empty, the storage offset of
// `self` is used.
Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
  const int64_t offset = storage_offset_.value_or(self.storage_offset());
  Tensor strided_view = detail::make_tensor<TensorImpl>(Storage(self.storage()), self.type_id());
  setStrided(strided_view, size, stride, offset);
  return strided_view;
}
// Quantized counterpart of as_strided_tensorimpl: the new tensor is backed by
// a QTensorImpl that reuses the storage and the quantizer of `self`. When
// `storage_offset_` is empty, the storage offset of `self` is used.
Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
  const int64_t offset = storage_offset_.value_or(self.storage_offset());
  Tensor strided_view = detail::make_tensor<QTensorImpl>(
      Storage(self.storage()), self.type_id(), get_qtensorimpl(self)->quantizer());
  setStrided(strided_view, size, stride, offset);
  return strided_view;
}
// In-place as_strided: applies the given size/stride (and storage offset,
// defaulting to the current one) to `self` and returns it.
Tensor &as_strided_(Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
  const int64_t offset = storage_offset_.value_or(self.storage_offset());
  setStrided(self, size, stride, offset);
  return self;
}
// Returns a new sparse tensor holding the range [start, start + length) of
// `self` along `dim`.
//
// `dim` is not wrapped: it must satisfy 0 <= dim < self.dim(). The range must
// be a subset of [0, self.size(dim)).
//
// When `dim` is a sparse dimension the entries whose index falls in the range
// are selected and their index along `dim` is shifted down by `start`. When
// `dim` is a dense dimension the indices are reused and only the values are
// narrowed. The result carries the coalesced flag of `self`.
Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
  int64_t allDim = self.dim();
  int64_t end = start+length;
  TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor.");
  TORCH_CHECK(dim >= 0 && dim < allDim,
    "Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, ".");
  // This check used to lack its terminating semicolon and compiled only
  // because of how the macro happens to expand.
  TORCH_CHECK(start >= 0 && length >= 0 && end <= self.size(dim),
    "Invalid range to narrow. range(start, start+length) must be a subset of range(0, ", self.size(dim), ").");
  Tensor indices = self._indices();
  int64_t sparse_dim = self.sparse_dim();

  std::vector<int64_t> new_sizes = self.sizes().vec();
  new_sizes[dim] = length;

  Tensor new_values;
  Tensor new_indices;
  if (dim < sparse_dim) {
    // Keep only the entries whose index along `dim` lies in [start, end).
    Tensor mask = (indices[dim] >= start).__and__((indices[dim] < end));
    new_indices = indices.masked_select(mask).view({sparse_dim, -1});
    // Re-base the kept indices so the narrowed range starts at 0.
    new_indices[dim].sub_(start);
    Tensor nzIndices = mask.nonzero().view(-1);
    new_values = self._values().index_select(0, nzIndices);
  } else {
    /* This means we are narrowing on a dense dim, which is in effect just a
        regular narrow on _values() */
    new_indices = indices;
    // Dimension 0 of the values tensor enumerates the entries, hence the +1.
    int64_t dense_dim = dim - sparse_dim + 1;
    new_values = self._values().narrow_copy(dense_dim, start, length);
  }

  auto newTensor = at::sparse_coo_tensor(new_indices, new_values, new_sizes);
  return newTensor._coalesced_(self.is_coalesced());
}
// Dense narrow_copy: narrows `self` and returns a clone of the result.
Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length){
  const Tensor narrowed = self.narrow(dim, start, length);
  return narrowed.clone();
}
// Returns the slice [start, start + length) of `self` along `dim`, implemented
// with at::slice (step 1). `start` may be negative and is wrapped against the
// size of the dimension.
Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
  TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
  const int64_t dim_size = self.size(dim);
  // start == dim_size is a valid (empty) start position, but it is not a valid
  // dim specification, so it must bypass maybe_wrap_dim.
  const int64_t begin = (start == dim_size) ? start : maybe_wrap_dim(start, dim_size);
  TORCH_CHECK(length >= 0 && begin <= dim_size - length,
           "start (", begin, ") + length (", length, ") exceeds dimension size (", dim_size, ").");
  return at::slice(self, dim, begin, begin + length, 1);
}
// Returns a view of `self` whose dimensions are reordered according to `dims`.
// `dims` must name every dimension exactly once (negative entries are
// wrapped).
Tensor permute(const Tensor& self, IntArrayRef dims) {
  const int64_t ndim = self.dim();
  TORCH_CHECK(dims.size() == static_cast<size_t>(ndim),
           "number of dims don't match in permute");

  const auto src_sizes = self.sizes();
  const auto src_strides = self.strides();
  std::vector<int64_t> out_sizes;
  std::vector<int64_t> out_strides;
  out_sizes.reserve(ndim);
  out_strides.reserve(ndim);
  std::vector<bool> already_used(ndim, false);

  for (const int64_t requested : dims) {
    const auto src_dim = maybe_wrap_dim(requested, ndim);
    TORCH_CHECK(!already_used[src_dim],
             "repeated dim in permute");
    already_used[src_dim] = true;
    out_sizes.push_back(src_sizes[src_dim]);
    out_strides.push_back(src_strides[src_dim]);
  }
  return self.as_strided(out_sizes, out_strides);
}
// Tiles `self` `repeats[i]` times along each dimension i. `repeats` may have
// more entries than `self` has dimensions, in which case `self` is treated as
// having extra leading dimensions of size 1.
Tensor repeat(const Tensor& self, IntArrayRef repeats) {
  TORCH_CHECK(repeats.size() >= (size_t)self.dim(),
           "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");

  // Pad the source shape with leading 1s when more target dimensions than
  // source dimensions were requested.
  const int64_t num_new_dimensions = repeats.size() - self.dim();
  std::vector<int64_t> padded_size(num_new_dimensions, 1);
  const auto self_sizes = self.sizes();
  padded_size.insert(padded_size.end(), self_sizes.begin(), self_sizes.end());

  std::vector<int64_t> target_size;
  target_size.reserve(repeats.size());
  for (size_t idx = 0; idx < repeats.size(); ++idx) {
    target_size.push_back(padded_size[idx] * repeats[idx]);
  }

  Tensor source = self.expand(padded_size);
  Tensor result = at::empty(target_size, self.options());

  // View `result` as a grid of source-shaped tiles by unfolding every
  // dimension with window = source extent.
  Tensor tiles = at::alias(result);
  const int64_t source_ndim = source.dim();
  for (int64_t d = 0; d < source_ndim; ++d) {
    const int64_t extent = source.size(d);
    // unfold rejects step 0, so use at least 1 (the value is irrelevant in
    // that case because the size is 0).
    tiles = tiles.unfold(d, extent, std::max<int64_t>(extent, 1));
  }

  tiles.copy_(source.expand_as(tiles));
  return result;
}
// Reshapes `self` to `proposed_shape` (resolved by infer_size).
// Sparse tensors are rejected and mkldnn tensors go to at::mkldnn_reshape.
// Otherwise a strided view is returned when THTensor_compute_stride finds
// compatible strides, and a view of a clone is returned when it does not.
Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
  if (self.is_sparse()) {
    AT_ERROR("reshape is not implemented for sparse tensors");
  }
  const auto inferred_shape = infer_size(proposed_shape, self.numel());

  if (self.is_mkldnn()) {
    return at::mkldnn_reshape(self, inferred_shape);
  }

  const auto view_stride = THTensor_compute_stride(self.sizes(), self.strides(), inferred_shape);
  if (!view_stride) {
    return at::_unsafe_view(self.clone(), inferred_shape);
  }
  return self.as_strided(inferred_shape, *view_stride);
}
// Reshapes `self` to the sizes of `other`.
Tensor reshape_as(const Tensor& self, const Tensor& other) {
  const IntArrayRef target_sizes = other.sizes();
  return self.reshape(target_sizes);
}
// Returns a view of `self` at position `index` along `dim`, with that
// dimension removed. Both `dim` and `index` may be negative.
Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  const int64_t ndim = self.dim();
  if (ndim == 0) {
    AT_INDEX_ERROR("select() cannot be applied to a 0-dim tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);

  const auto dim_size = self.size(dim);
  const bool index_in_range = index >= -dim_size && index < dim_size;
  if (!index_in_range) {
    AT_INDEX_ERROR("select(): index ", index, " out of range for tensor of size ",
                   self.sizes(), " at dimension ", dim);
  }
  // Negative indices count from the end of the dimension.
  const int64_t position = index < 0 ? index + dim_size : index;

  auto out_sizes = self.sizes().vec();
  auto out_strides = self.strides().vec();
  const auto storage_offset = self.storage_offset() + position * out_strides[dim];
  out_sizes.erase(out_sizes.begin() + dim);
  out_strides.erase(out_strides.begin() + dim);
  return self.as_strided(out_sizes, out_strides, storage_offset);
}
// Returns a strided view of `self` covering [start, end) along `dim` with the
// given positive `step`. Negative bounds are offset by the dimension size and
// both bounds are then clamped so that 0 <= start <= end <= size.
Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
  const int64_t ndim = self.dim();
  if (ndim == 0) {
    AT_INDEX_ERROR("slice() cannot be applied to a 0-dim tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);
  auto out_sizes = self.sizes().vec();
  auto out_strides = self.strides().vec();
  // TODO: support negative strides
  TORCH_CHECK(step > 0, "slice step must be positive");

  const int64_t dim_size = out_sizes[dim];
  if (start < 0) {
    start += dim_size;
  }
  if (end < 0) {
    end += dim_size;
  }
  // Clamp start into [0, dim_size] and end into [start, dim_size].
  start = std::min(std::max<int64_t>(start, 0), dim_size);
  end = std::max(start, std::min(end, dim_size));

  const auto storage_offset = self.storage_offset() + start * out_strides[dim];
  const auto span = end - start;
  out_sizes[dim] = (span + step - 1) / step;  // round-up
  out_strides[dim] *= step;
  return self.as_strided(out_sizes, out_strides, storage_offset);
}
std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
int64_t dim_size = self.size(dim);
TORCH_CHECK(split_size > 0 || self.size(dim) == 0,
"split_size can only be 0 if dimension size is 0, "
"but got dimension size of ", dim_size);
// if split_size is 0 and dimension size is 0, there is 1 split.
int64_t num_splits = 1;
if (split_size != 0) {
// ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
// (returns a single split). We might want to error here, but keep it for BC.
num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
}
std::vector<Tensor> splits(num_splits);
int64_t last_split_size = split_size - (split_size * num_splits - dim_size);
for (int64_t i = 0; i < num_splits; ++i) {
auto length = i < num_splits - 1 ? split_size : last_split_size;
splits[i] = self.narrow(dim, i * split_size, length);
}
return splits;
}
std::vector<Tensor> split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
int64_t dim_size = self.size(dim);
int64_t num_splits = split_sizes.size();
std::vector<Tensor> splits(num_splits);
int64_t start_idx = 0;
int64_t i;
for (i = 0; i < num_splits; ++i) {
auto length = split_sizes[i];
TORCH_CHECK(length >= 0,
"split_with_sizes expects split_sizes have only non-negative ",
"entries, but got split_sizes=", split_sizes);
splits[i] = self.narrow(dim, start_idx, length);
start_idx += length;
}
TORCH_CHECK(start_idx == dim_size,
"split_with_sizes expects split_sizes to sum exactly to ", dim_size,
" (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes);
return splits;
}
// Returns every tensor of the list unsqueezed at `dim`, in the same order.
static inline std::vector<Tensor> get_stack_inputs(TensorList tensors, int64_t dim) {
  std::vector<Tensor> inputs;
  inputs.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    inputs.push_back(tensor.unsqueeze(dim));
  }
  return inputs;
}
// Stacks the tensors along a new dimension `dim`: each input is unsqueezed at
// `dim` and the results are concatenated there. The list must be non-empty.
Tensor stack(TensorList tensors, int64_t dim) {
  TORCH_CHECK(tensors.size() > 0,
           "stack expects a non-empty TensorList");
  // The result has one more dimension than the inputs, so wrap against
  // dim() + 1.
  const int64_t wrapped_dim = maybe_wrap_dim(dim, tensors[0].dim() + 1);
  const std::vector<Tensor> unsqueezed = get_stack_inputs(tensors, wrapped_dim);
  return at::cat(unsqueezed, wrapped_dim);
}
// Out-variant of stack: same as stack, but the concatenation is written into
// `result` through at::cat_out.
Tensor& stack_out(Tensor& result, TensorList tensors, int64_t dim) {
  TORCH_CHECK(tensors.size() > 0,
           "stack expects a non-empty TensorList");
  // The result has one more dimension than the inputs, so wrap against
  // dim() + 1.
  const int64_t wrapped_dim = maybe_wrap_dim(dim, tensors[0].dim() + 1);
  const std::vector<Tensor> unsqueezed = get_stack_inputs(tensors, wrapped_dim);
  return at::cat_out(result, unsqueezed, wrapped_dim);
}
// In-place transpose of two sparse dimensions of a sparse tensor.
// Both dimensions must be below sparse_dim. When the tensor has entries, the
// two index rows are swapped and the tensor is marked as not coalesced; in
// both cases the sizes are swapped through raw_resize_.
static inline Tensor & sparse_transpose_(Tensor & self, int64_t dim0, int64_t dim1) {
  const int64_t nsparse_dim = self.sparse_dim();
  TORCH_CHECK(dim0 < nsparse_dim && dim1 < nsparse_dim,
           "sparse transpose: transposed dimensions must be sparse ",
           "Got sparse_dim: ", nsparse_dim, ", d0: ", dim0, ", d1: ", dim1);

  const bool has_no_entries =
      self._indices().numel() == 0 && self._values().numel() == 0;
  if (has_no_entries) {
    // Nothing to permute in the indices; only the sizes change.
    auto swapped_sizes = self.sizes().vec();
    std::swap(swapped_sizes[dim0], swapped_sizes[dim1]);
    at::sparse::get_sparse_impl(self)->raw_resize_(self.sparse_dim(), self.dense_dim(), swapped_sizes);
    return self;
  }

  auto indices = self._indices();
  auto first_row = indices.select(0, dim0);
  auto second_row = indices.select(0, dim1);

  // Swap the two index rows through a temporary copy of the first one.
  auto saved_first = at::zeros_like(first_row);
  saved_first.copy_(first_row);
  first_row.copy_(second_row);
  second_row.copy_(saved_first);

  self._coalesced_(false);

  auto swapped_sizes = self.sizes().vec();
  std::swap(swapped_sizes[dim0], swapped_sizes[dim1]);
  at::sparse::get_sparse_impl(self)->raw_resize_(self._indices().size(0), self._values().dim() - 1, swapped_sizes);
  return self;
}
// In-place transpose of dimensions `dim0` and `dim1` (both wrapped).
// Identical dimensions are a no-op; sparse tensors are delegated to
// sparse_transpose_; otherwise sizes and strides are swapped in place.
Tensor & transpose_(Tensor & self, int64_t dim0, int64_t dim1) {
  const auto ndims = self.dim();
  dim0 = maybe_wrap_dim(dim0, ndims);
  dim1 = maybe_wrap_dim(dim1, ndims);
  if (dim0 == dim1) {
    return self;
  }
  if (self.is_sparse()) {
    return sparse_transpose_(self, dim0, dim1);
  }

  auto new_sizes = self.sizes().vec();
  auto new_strides = self.strides().vec();
  std::swap(new_sizes[dim0], new_sizes[dim1]);
  std::swap(new_strides[dim0], new_strides[dim1]);
  return self.as_strided_(new_sizes, new_strides);
}
// Returns `self` with dimensions `dim0` and `dim1` (both wrapped) exchanged.
// Identical dimensions return `self` as is; sparse tensors are cloned and
// transposed in place; dense tensors yield a strided view.
Tensor transpose(const Tensor & self, int64_t dim0, int64_t dim1) {
  const auto ndims = self.dim();
  dim0 = maybe_wrap_dim(dim0, ndims);
  dim1 = maybe_wrap_dim(dim1, ndims);
  if (dim0 == dim1) {
    return self;
  }
  if (self.is_sparse()) {
    Tensor self_clone = self.clone();  // yes, this is what THS does
    return sparse_transpose_(self_clone, dim0, dim1);
  }

  auto new_sizes = self.sizes().vec();
  auto new_strides = self.strides().vec();
  std::swap(new_sizes[dim0], new_sizes[dim1]);
  std::swap(new_strides[dim0], new_strides[dim1]);
  return self.as_strided(new_sizes, new_strides);
}
// Precondition check shared by t() and t_(): dense tensors may have at most 2
// dimensions; sparse tensors may have at most 2 sparse dimensions and no dense
// dimensions. `fn` is the caller's name, used as the message prefix.
static void check_t(const Tensor& self, const char *fn) {
  if (!self.is_sparse()) {
    TORCH_CHECK(self.dim() <= 2,
             fn, " expects a tensor with <= 2 dimensions, but self is ", self.dim(), "D");
    return;
  }

  const int64_t sparse_dim = self.sparse_dim();
  const int64_t dense_dim = self.dense_dim();
  TORCH_CHECK(sparse_dim <= 2 && dense_dim == 0,
           fn, " expects a tensor with <= 2 sparse and 0 dense dimensions, but got ",
           sparse_dim, " sparse and ", dense_dim, " dense dimensions");
}
// Transposes dimensions 0 and 1 of a tensor accepted by check_t. With fewer
// than 2 dimensions the call degenerates to transpose(0, 0).
Tensor t(const Tensor & self) {
  check_t(self, "t()");
  const int64_t other_dim = self.dim() < 2 ? 0 : 1;
  return self.transpose(0, other_dim);
}
// In-place counterpart of t(): transposes dimensions 0 and 1, or calls
// transpose_(0, 0) when the tensor has fewer than 2 dimensions.
Tensor & t_(Tensor & self) {
  check_t(self, "t_()");
  const int64_t other_dim = self.dim() < 2 ? 0 : 1;
  return self.transpose_(0, other_dim);
}
// Computes the (sizes, strides) of `tensor` with every dimension of size 1
// removed.
std::tuple<std::vector<int64_t>, std::vector<int64_t> >
inferSqueezeGeometry(const Tensor &tensor) {
  const auto in_sizes = tensor.sizes();
  const auto in_strides = tensor.strides();
  const int64_t ndim = tensor.dim();

  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  for (int64_t d = 0; d < ndim; ++d) {
    if (in_sizes[d] == 1) {
      continue;
    }
    sizes.push_back(in_sizes[d]);
    strides.push_back(in_strides[d]);
  }
  return std::make_tuple(sizes, strides);
}
// Computes the (sizes, strides) of `tensor` with dimension `dim` removed if,
// and only if, that dimension has size 1. All other dimensions are kept.
std::tuple<std::vector<int64_t>, std::vector<int64_t> >
inferSqueezeGeometry(const Tensor& tensor, int64_t dim) {
  const auto in_sizes = tensor.sizes();
  const auto in_strides = tensor.strides();
  const int64_t ndim = tensor.dim();

  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  for (int64_t d = 0; d < ndim; ++d) {
    const bool drop = (d == dim) && (in_sizes[d] == 1);
    if (drop) {
      continue;
    }
    sizes.push_back(in_sizes[d]);
    strides.push_back(in_strides[d]);
  }
  return std::make_tuple(sizes, strides);
}
// Computes the (sizes, strides) of `tensor` with a new dimension of size 1
// inserted at position `dim`. The new stride is 1 when the dimension is
// appended at the end, and sizes[dim] * strides[dim] otherwise.
std::tuple<std::vector<int64_t>, std::vector<int64_t> >
inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
  auto out_sizes = tensor.sizes().vec();
  auto out_strides = tensor.strides().vec();

  int64_t inserted_stride = 1;
  if (dim < tensor.dim()) {
    inserted_stride = out_sizes[dim] * out_strides[dim];
  }

  out_sizes.insert(out_sizes.begin() + dim, 1);
  out_strides.insert(out_strides.begin() + dim, inserted_stride);
  return std::make_tuple(out_sizes, out_strides);
}
// Returns a view of `self` with all size-1 dimensions removed.
Tensor squeeze(const Tensor& self) {
  const auto geometry = inferSqueezeGeometry(self);
  const auto& out_sizes = std::get<0>(geometry);
  const auto& out_strides = std::get<1>(geometry);
  return self.as_strided(out_sizes, out_strides);
}
// Returns a view of `self` with dimension `dim` (wrapped) removed when it has
// size 1. For 0-dim tensors, or when the dimension is not of size 1, a view
// with unchanged geometry is returned.
Tensor squeeze(const Tensor& self, int64_t dim) {
  const int64_t ndim = self.dim();
  dim = maybe_wrap_dim(dim, ndim);

  const bool squeezable = ndim != 0 && self.sizes()[dim] == 1;
  if (!squeezable) {
    return self.as_strided(self.sizes(), self.strides());
  }

  const auto geometry = inferSqueezeGeometry(self, dim);
  return self.as_strided(std::get<0>(geometry), std::get<1>(geometry));
}
// In-place squeeze: removes all size-1 dimensions from `self`.
Tensor & squeeze_(Tensor& self) {
  const auto geometry = inferSqueezeGeometry(self);
  const auto& out_sizes = std::get<0>(geometry);
  const auto& out_strides = std::get<1>(geometry);
  return self.as_strided_(out_sizes, out_strides);
}
// In-place squeeze of a single dimension: removes dimension `dim` (wrapped)
// from `self` when it has size 1. For 0-dim tensors, or when the dimension is
// not of size 1, the current geometry is re-applied unchanged.
Tensor & squeeze_(Tensor& self, int64_t dim) {
  const int64_t ndim = self.dim();
  dim = maybe_wrap_dim(dim, self.dim());

  const bool squeezable = ndim != 0 && self.sizes()[dim] == 1;
  if (!squeezable) {
    return self.as_strided_(self.sizes(), self.strides());
  }

  const auto geometry = inferSqueezeGeometry(self, dim);
  return self.as_strided_(std::get<0>(geometry), std::get<1>(geometry));
}
// _unsafe_view() is view() under another name: the difference is that its
// result is not treated as a view for the purposes of automatic
// differentiation (it is not listed in VIEW_FUNCTIONS in gen_autograd.py).
// It is therefore only safe when the `self` tensor is a temporary. In the
// example below, the viewed tensor (a + b) is discarded right after viewing:
//
//  res = at::_unsafe_view(a + b, size);
//
// This is a hack: in-place operations on tensors treated like views can be
// much more expensive than the same operations on non-view tensors.
Tensor _unsafe_view(const Tensor& self, IntArrayRef size) {
  Tensor viewed = self.view(size);
  return viewed;
}
// Inserts a dimension of size 1 at `dim` into a sparse tensor.
// For dim <= sparse_dim the new dimension is sparse: a row of zeros is
// inserted into the indices at position `dim`. Otherwise the new dimension is
// dense and the values tensor is unsqueezed instead.
static Tensor unsqueeze_sparse(Tensor const &self, int64_t dim /* should already be wrapped */) {
  const int64_t sparse_dim = self.sparse_dim();
  const int64_t dense_dim = self.dense_dim();
  auto indices = self._indices();
  auto new_sizes = self.sizes().vec();
  new_sizes.insert(new_sizes.begin() + dim, 1);

  if (dim > sparse_dim) {
    // Dimension 0 of the values tensor enumerates the entries, hence the +1.
    const auto new_values = self._values().unsqueeze(dim - sparse_dim + 1);
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim + 1, new_sizes, indices, new_values, self.options());
  }

  const auto rows_before = indices.narrow(0, 0, dim);
  const auto zero_row = native::zeros({1, indices.size(1)}, indices.options().dtype(kLong));
  const auto rows_after = indices.narrow(0, dim, indices.size(0) - dim);
  const auto new_indices = native::cat({rows_before, zero_row, rows_after});
  return _sparse_coo_tensor_with_dims_and_tensors(
      sparse_dim + 1, dense_dim, new_sizes, new_indices, self._values(), self.options());
}
// Returns `self` with a dimension of size 1 inserted at `dim` (wrapped against
// dim() + 1). Sparse tensors are handled by unsqueeze_sparse, dense tensors by
// a strided view.
Tensor unsqueeze(const Tensor& self, int64_t dim) {
  dim = maybe_wrap_dim(dim, self.dim() + 1);
  if (self.is_sparse()) {
    return unsqueeze_sparse(self, dim);
  }

  const auto geometry = inferUnsqueezeGeometry(self, dim);
  return self.as_strided(std::get<0>(geometry), std::get<1>(geometry));
}
// In-place unsqueeze: inserts a dimension of size 1 at `dim` (wrapped against
// dim() + 1) by re-striding `self`.
Tensor & unsqueeze_(Tensor& self, int64_t dim) {
  dim = maybe_wrap_dim(dim, self.dim() + 1);
  const auto geometry = inferUnsqueezeGeometry(self, dim);
  const auto& out_sizes = std::get<0>(geometry);
  const auto& out_strides = std::get<1>(geometry);
  return self.as_strided_(out_sizes, out_strides);
}
// Collapses dimensions start_dim..end_dim (inclusive, both wrapped) of `self`
// into a single dimension. Returns `self` unchanged when both bounds name the
// same dimension.
Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
  const int64_t ndim = self.dim();
  start_dim = maybe_wrap_dim(start_dim, ndim);
  end_dim = maybe_wrap_dim(end_dim, ndim);
  TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");

  if (start_dim == end_dim) {
    return self;
  }

  // The collapsed extent is computed explicitly instead of passing -1 through
  // infer_size on the whole shape, because that would leave an unwanted degree
  // of freedom. Consider shape [0, 1, 3, 0] with start_dim=1, end_dim=2: the
  // result should clearly be [0, 3, 0], but in [0, -1, 0] the -1 could take any
  // value and still satisfy the constraints.
  const auto in_sizes = self.sizes();
  const auto slice_numel = prod_intlist(in_sizes.slice(start_dim, end_dim - start_dim + 1));

  std::vector<int64_t> shape;
  shape.reserve(ndim - end_dim + start_dim);
  shape.insert(shape.end(), in_sizes.begin(), in_sizes.begin() + start_dim);
  shape.push_back(slice_numel);
  shape.insert(shape.end(), in_sizes.begin() + end_dim + 1, in_sizes.end());

  return self.reshape(shape);
}
// Views `self` with the sizes of `other`.
Tensor view_as(const Tensor& self, const Tensor& other) {
  const IntArrayRef target_sizes = other.sizes();
  return self.view(target_sizes);
}
// Returns the element count reported by the underlying TensorImpl.
int64_t numel(const Tensor& self) {
  const auto* impl = self.unsafeGetTensorImpl();
  return impl->numel();
}
std::vector<Tensor> unbind(const Tensor &self, int64_t dim) {
dim = maybe_wrap_dim(dim, self.dim());
int64_t size = self.size(dim);
std::vector<Tensor> tensors(size);
for (int i = 0; i < size; i++) {
tensors[i] = self.select(dim, i);
}
return tensors;
}
// Builds one grid per input tensor. Each input must be a scalar or a 1-D
// tensor, and all inputs must share dtype and device. Grid i is input i viewed
// with its extent in dimension i (and 1 elsewhere), expanded to the common
// shape whose i-th extent is the length of input i (1 for scalars).
std::vector<Tensor> meshgrid(TensorList tensors) {
  const int64_t size = tensors.size();
  TORCH_CHECK(size > 0, "meshgrid expects a non-empty TensorList");

  // Common output shape: one extent per input.
  std::vector<int64_t> shape(size);
  for (int64_t i = 0; i < size; i++) {
    const int64_t rank = tensors[i].dim();
    if (rank == 0) {
      shape[i] = 1;
    } else if (rank == 1) {
      shape[i] = tensors[i].size(0);
    } else {
      AT_ERROR("Expected scalar or 1D tensor in the tensor list but got: ", tensors[i]);
    }
  }

  // Neighbouring inputs must agree on dtype and device.
  for (int64_t next = 1; next < size; next++) {
    const int64_t prev = next - 1;
    TORCH_CHECK(tensors[prev].dtype() == tensors[next].dtype(), "meshgrid expects all tensors to have the same dtype");
    TORCH_CHECK(tensors[prev].device() == tensors[next].device(), "meshgrid expects all tensors to have the same device");
  }

  std::vector<Tensor> grids;
  grids.reserve(size);
  for (int64_t i = 0; i < size; i++) {
    std::vector<int64_t> view_shape(size, 1);
    view_shape[i] = -1;
    const Tensor axis = tensors[i].view(view_shape);
    grids.push_back(axis.expand(shape));
  }
  return grids;
}
// Numpy-style `a.T`: returns the tensor with the order of its dimensions
// reversed, implemented as a permute.
Tensor T(const Tensor &self) {
  const int64_t ndim = self.dim();
  DimVector reversed_dims;
  for (int64_t d = 0; d < ndim; ++d) {
    reversed_dims.push_back(ndim - 1 - d);
  }
  return self.permute(reversed_dims);
}
}
}