| // Indexing tensors by tensors |
| // |
| // This corresponds to "advanced indexing" in NumPy. The two operations are: |
| // |
| // index(Tensor self, indices) -> Tensor |
| // index_put_(Tensor self, indices, value, accumulate=false) |
| // |
| // The index is a TensorList containing kLong, kBool or kByte tensors or nulls. Byte |
| // tensors (boolean masks) are expanded to long tensors via nonzero(). Null |
| // tensors signify that the dimension is not indexed. |
| // |
| // All indexes are broadcast together and iterated as *one*. From NumPy: |
| // |
| // result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M], |
| // ..., ind_N[i_1, ..., i_M]] |
| // |
| // Note 1: ByteTensors expand to index as many dimensions as there are in the |
| // mask. |
| // |
| // Note 2: The behavior is more complicated when the index tensors are not all |
| // adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index |
| // tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]] |
| // |
| // The code contains two implementations of indexing. The more efficient |
| // implementation treats indexing like an elementwise operation over the |
| // tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does |
| // not work for index_put_ with accumulate=True. The other implementation |
| // combines the indexed tensors into a single linear index that is used |
| // with Tensor.put_. This is used for index_put_ with accumulate=True. |
| // |
| // The more efficient implementation takes the following steps for the |
| // above operation: |
| // |
| // 1) Broadcast ind_1, ind_2, ..., ind_N together to a common shape |
| // 2) Record x.stride(i) for each indexed dimension `i` |
| // 3) Replace the indexed subspace of `x` with the shape of the corresponding |
| // subspace of `result` but with stride 0 |
| // 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so |
| // that their shape is compatible with the result shape |
| // |
| // The CPU or CUDA kernel then computes element-wise over the broadcasted |
| // and restrided result, x, ind_1, ind_2, etc.: |
| // |
| // result[...] = *(&x[...] + |
| // ind_1[...] * x.stride(1) + |
| // ind_2[...] * x.stride(2) + |
| // ...) |
| // |
| // where & and * represent the C-style address-of and indirection operations. |
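| // |
| // Worked example (illustrative): let x have shape [5, 6, 7] and let it be |
| // indexed at dimension 1 only, by a single long tensor ind of shape [2, 3] |
| // (i.e. result = x[:, ind] in Python notation). Then: |
| // |
| //   1) ind trivially broadcasts to shape [2, 3] |
| //   2) x.stride(1) is recorded for the indexed dimension |
| //   3) x is restrided to shape [5, 2, 3, 7] with strides |
| //      (x.stride(0), 0, 0, x.stride(2)) |
| //   4) ind is reshaped to [1, 2, 3, 1] so that it broadcasts with the result |
| // |
| // The kernel then computes, for every output element, |
| // |
| //   result[a, i, j, c] = x[a, ind[i, j], c] |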
| |
| #include <ATen/native/TensorAdvancedIndexing.h> |
| #include <ATen/native/IndexKernel.h> |
| #include <ATen/native/IndexingUtils.h> |
| |
| #include <ATen/ATen.h> |
| #include <ATen/NativeFunctions.h> |
| #include <ATen/ExpandUtils.h> |
| #include <ATen/MemoryOverlap.h> |
| #include <ATen/native/TensorAdvancedIndexingUtils.h> |
| #include <ATen/native/TensorIterator.h> |
| #include <ATen/native/BinaryOps.h> |
| #include <ATen/native/Copy.h> |
| #include <ATen/native/Resize.h> |
| #include <ATen/native/ScatterGatherChecks.h> |
| #include <ATen/Parallel.h> |
| #include <ATen/NumericUtils.h> |
| #include <ATen/TensorSubclassLikeUtils.h> |
| |
| #include <c10/util/irange.h> |
| #include <c10/util/Unroll.h> |
| |
| #include <algorithm> |
| #include <functional> |
| #include <numeric> |
| #include <vector> |
| |
| namespace at { |
| namespace meta { |
| |
| native::SCATTER_GATHER_OP get_operator_enum(const c10::string_view reduce, bool use_new_options = false) { |
| if (use_new_options) { |
| if (reduce == "sum") { |
| return native::SCATTER_GATHER_OP::REDUCE_ADD; |
| } else if (reduce == "prod") { |
| return native::SCATTER_GATHER_OP::REDUCE_MULTIPLY; |
| } else if (reduce == "mean") { |
| return native::SCATTER_GATHER_OP::REDUCE_MEAN; |
| } else if (reduce == "amax") { |
| return native::SCATTER_GATHER_OP::REDUCE_MAXIMUM; |
| } else if (reduce == "amin") { |
| return native::SCATTER_GATHER_OP::REDUCE_MINIMUM; |
| } else { |
| TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin."); |
| } |
| } else { |
| if (reduce == "add") { |
| return native::SCATTER_GATHER_OP::REDUCE_ADD; |
| } else if (reduce == "multiply") { |
| return native::SCATTER_GATHER_OP::REDUCE_MULTIPLY; |
| } else { |
| TORCH_CHECK(false, "reduce argument must be either add or multiply.") |
| } |
| } |
| } |
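| // For example (illustrative), get_operator_enum("amax", /*use_new_options=*/true) |
| // yields SCATTER_GATHER_OP::REDUCE_MAXIMUM, while the legacy spelling |
| // get_operator_enum("multiply") yields SCATTER_GATHER_OP::REDUCE_MULTIPLY. |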
| |
| TORCH_META_FUNC(gather) |
| (const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) { |
| const Tensor& result = maybe_get_output(0); |
| int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim()); |
| |
| // Memory overlap checks need to be done after any required resizing has happened. |
| // But these checks only make sense when the result was already defined, hence |
| // the boolean variable `check_result` here. |
| // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832 |
| // and https://github.com/pytorch/pytorch/issues/63837 |
| bool check_result = result.defined(); |
| set_output_raw_strided(0, index.sizes(), {}, self.options()); |
| if (check_result) { |
| at::assert_no_internal_overlap(result); |
| at::assert_no_overlap(result, self); |
| at::assert_no_partial_overlap(result, index); |
| } |
| |
| auto is_index_empty = index.numel() == 0; |
| if (!is_index_empty) { |
| TORCH_CHECK( |
| index.scalar_type() == at::ScalarType::Long, |
| "gather", "(): Expected dtype int64 for index" |
| ); |
| } |
| if (is_index_empty) return; |
| at::native::gather_shape_check(self, wrapped_dim, index); |
| } |
| |
| template <bool use_new_options = false, typename Meta> |
| void scatter_meta_impl( |
| Meta& meta, |
| const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const c10::optional<Tensor>& src = nullopt, |
| const c10::optional<c10::string_view> reduce = nullopt) { |
| int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim()); |
| at::native::scatter_gather_dtype_check("scatter", self, index, src); |
| at::native::scatter_shape_check(self, wrapped_dim, index, src); |
| auto output = meta.maybe_get_output(0); |
| |
| if (output.defined()) { |
| at::assert_no_internal_overlap(output); |
| at::assert_no_overlap(output, index); |
| if (src.has_value()) { |
| at::assert_no_overlap(output, src.value()); |
| } |
| } |
| |
| meta.set_output_raw_strided(0, self.sizes(), {}, self.options()); |
| if (reduce.has_value()) { |
| // Check if we have a valid reduce operator. |
| get_operator_enum(reduce.value(), use_new_options); |
| } |
| } |
| |
| TORCH_META_FUNC2(scatter, src) |
| (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { |
| scatter_meta_impl(*this, self, dim, index, src); |
| } |
| |
| TORCH_META_FUNC2(scatter, value) |
| (const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) { |
| scatter_meta_impl(*this, self, dim, index); |
| } |
| |
| TORCH_META_FUNC2(scatter, reduce) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& src, |
| const c10::string_view reduce) { |
| scatter_meta_impl(*this, self, dim, index, src, reduce); |
| } |
| |
| TORCH_META_FUNC2(scatter, value_reduce) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Scalar& src, |
| const c10::string_view reduce) { |
| scatter_meta_impl(*this, self, dim, index, nullopt, reduce); |
| } |
| |
| TORCH_META_FUNC(scatter_add) |
| (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { |
| scatter_meta_impl(*this, self, dim, index, src, "add"); |
| } |
| |
| TORCH_META_FUNC2(scatter_reduce, two) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& src, |
| const c10::string_view reduce, |
| bool include_self) { |
| (void) include_self; |
| scatter_meta_impl</*use_new_options=*/true>(*this, self, dim, index, src, reduce); |
| } |
| |
| TORCH_PRECOMPUTE_META_FUNC(index_copy) |
| (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source) { |
| dim = maybe_wrap_dim(dim, self.dim()); |
| |
| const Tensor& result = maybe_get_output(0); |
| |
| // Memory overlap checks need to be done after any required resizing has happened. |
| // But these checks only make sense when the result was already defined, hence |
| // the boolean variable `check_result` here. |
| // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832 |
| // and https://github.com/pytorch/pytorch/issues/63837 |
| bool check_result = result.defined(); |
| set_output_raw_strided(0, self.sizes(), {}, self.options()); |
| if (check_result) { |
| at::assert_no_internal_overlap(result); |
| at::assert_no_overlap(result, index); |
| at::assert_no_overlap(result, source); |
| } |
| |
| TORCH_CHECK_INDEX(index.dim() < 2, "index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")"); |
| |
| int64_t numIndices = index.numel(); |
| if (source.dim() == 0 && numIndices != 1) { |
| TORCH_CHECK_INDEX(false, "index_copy_(): When source is scalar, index should have one element (got ", numIndices, ")"); |
| } else if ((source.dim() != self.dim()) && (source.dim() != 0 && self.dim() != 0)) { |
| TORCH_CHECK_INDEX(false, "index_copy_(): When source and destination are not scalars, their dimensionality must match. Source dimensionality (", |
| source.dim(), "), destination dimensionality (", self.dim(), ")"); |
| } |
| |
| TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_copy_(): Expected a long tensor for index, but got ", index.scalar_type()); |
| TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_copy_(): self and source expected to have the same dtype, but got (self) ", self.scalar_type(), " and (source) ", source.scalar_type()); |
| TORCH_CHECK(self.device() == source.device() && self.device() == index.device(), |
| "index_copy_(): self, index and source expected to be in the same device, but got (self) ", |
| self.device(), ", (index) ", index.device(), ", and (source) ", source.device()); |
| |
| // Check that source and destination slices have the same size |
| auto selfSlicedSizes = self.sizes().vec(); |
| if (selfSlicedSizes.size() > 0) { |
| selfSlicedSizes.erase(selfSlicedSizes.begin() + dim); |
| } |
| auto sourceSlicedSizes = source.sizes().vec(); |
| if (sourceSlicedSizes.size() > 0) { |
| sourceSlicedSizes.erase(sourceSlicedSizes.begin() + dim); |
| } |
| if (selfSlicedSizes.size() != sourceSlicedSizes.size() || |
| !std::equal(selfSlicedSizes.begin(), selfSlicedSizes.end(), |
| sourceSlicedSizes.begin())) { |
| std::stringstream ss; |
| ss << "index_copy_(): Source/destination tensor must have same slice shapes. "; |
| ss << "Destination slice shape: " << selfSlicedSizes << " at dimension " << dim; |
| ss << " and source slice shape: " << sourceSlicedSizes << " at dimension 0."; |
| TORCH_CHECK(false, ss.str()); |
| } |
| TORCH_CHECK_INDEX(source.dim() == 0 || numIndices == source.size(dim), |
| "index_copy_(): Number of indices (", numIndices, ") should be equal to source.size(dim) (", source.size(dim), ")"); |
| |
| return TORCH_PRECOMPUTE_STRUCT(index_copy)().set_dim(dim); |
| } |
| |
| template <typename Meta> |
| void index_func_meta_impl( |
| Meta& meta, |
| const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& source, |
| c10::string_view func) { |
| auto numel = index.numel(); |
| |
| TORCH_CHECK_INDEX(index.dim() <= 1, func, "_(): Index is supposed to be a vector, but got dim: ", |
| index.dim(), " with type: ", index.scalar_type(), " and size: ", index.sizes()); |
| TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, |
| func, "_(): Expected dtype int32/int64 for index but got: ", index.scalar_type()); |
| TORCH_CHECK(self.scalar_type() == source.scalar_type(), |
| func, "_(): self (", self.scalar_type(), ") and source (", source.scalar_type(), |
| ") must have the same scalar type"); |
| TORCH_CHECK(dim == 0 || dim < source.dim(), |
| func, "_(): Indexing dim ", dim, " is out of bounds of the source tensor with dim ", |
| source.dim()); |
| TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)), |
| func, "_(): Number of indices (", numel, ") should be equal to source.size(dim): (", |
| source.size(dim), "), for dim: ", dim); |
| |
| auto& result = meta.maybe_get_output(0); |
| bool is_defined = result.defined(); |
| meta.set_output_raw_strided(0, self.sizes(), {}, self.options()); |
| if (is_defined) { |
| at::assert_no_internal_overlap(result); |
| at::assert_no_overlap(result, index); |
| at::assert_no_overlap(result, source); |
| } |
| |
| // A hack to run TensorIterator checks in the meta function. |
| // See comment: https://github.com/pytorch/pytorch/pull/65993#discussion_r760307417 |
| // TODO: (@krshrimali) Try inheriting from TensorIteratorBase instead. |
| if (result.device() == kMeta && result.dim() > 0) { |
| auto selfSlice = result.select(dim, 0); |
| auto sourceSlice = source.select(dim, 0); |
| auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice); |
| } |
| } |
| |
| TORCH_PRECOMPUTE_META_FUNC(index_add) |
| (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha) { |
| dim = maybe_wrap_dim(dim, self.dim()); |
| index_func_meta_impl(*this, self, dim, index, source, "index_add"); |
| return TORCH_PRECOMPUTE_STRUCT(index_add)().set_dim(dim); |
| } |
| |
| TORCH_PRECOMPUTE_META_FUNC(index_reduce) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& source, |
| const c10::string_view reduce, |
| bool include_self) { |
| (void)include_self; |
| TORCH_CHECK(reduce == "prod" || reduce == "mean" || reduce == "amax" || reduce == "amin", |
| "index_reduce(): Expected reduce to be one of prod, mean, amax or amin but got ", reduce, "."); |
| dim = maybe_wrap_dim(dim, self.dim()); |
| index_func_meta_impl(*this, self, dim, index, source, "index_reduce"); |
| return TORCH_PRECOMPUTE_STRUCT(index_reduce)().set_dim(dim); |
| } |
| |
| } // namespace meta |
| |
| namespace native { |
| |
| DEFINE_DISPATCH(index_stub); |
| DEFINE_DISPATCH(index_fill_stub); |
| DEFINE_DISPATCH(index_copy_stub); |
| DEFINE_DISPATCH(index_put_stub); |
| DEFINE_DISPATCH(index_put_with_sort_stub); |
| DEFINE_DISPATCH(put_stub); |
| DEFINE_DISPATCH(take_stub); |
| DEFINE_DISPATCH(masked_fill_stub); |
| REGISTER_NO_CPU_DISPATCH(index_put_with_sort_stub); |
| DEFINE_DISPATCH(masked_select_serial_stub); |
| DEFINE_DISPATCH(masked_select_stub); |
| DEFINE_DISPATCH(masked_scatter_stub); |
| |
| DEFINE_DISPATCH(gather_stub); |
| DEFINE_DISPATCH(scatter_stub); |
| DEFINE_DISPATCH(scatter_fill_stub); |
| DEFINE_DISPATCH(scatter_add_stub); |
| DEFINE_DISPATCH(scatter_reduce_stub); |
| DEFINE_DISPATCH(scatter_scalar_reduce_stub); |
| DEFINE_DISPATCH(scatter_reduce_two_stub); |
| |
| static bool all_strides_match(TensorList tensors) { |
| TORCH_CHECK(tensors.size() >= 1); |
| auto strides = tensors[0].strides(); |
| for (auto& tensor : tensors.slice(1)) { |
| if (!strides.equals(tensor.strides())) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| // Replace indexed dimensions in src with stride 0 and the size of the result tensor. |
| // The offset in these dimensions is computed by the kernel using the index tensor's |
| // values and the stride of src. The new shape is not meaningful. It's used to make |
| // the shape compatible with the result tensor. |
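| // Example (illustrative): for src of shape [5, 6, 7] with dims_before = 1, |
| // dims_indexed = 1 and replacement_shape = [2, 3], the returned view has |
| // shape [5, 2, 3, 7] and strides (src.stride(0), 0, 0, src.stride(2)): the |
| // indexed dimension of size 6 is dropped and replaced by two stride-0 |
| // dimensions that match the broadcast index shape. |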
| static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed, |
| IntArrayRef replacement_shape) { |
| auto shape = DimVector(src.sizes()); |
| auto strides = DimVector(src.strides()); |
| int64_t end = dims_before + dims_indexed; |
| shape.erase(shape.begin() + dims_before, shape.begin() + end); |
| strides.erase(strides.begin() + dims_before, strides.begin() + end); |
| shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end()); |
| strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0); |
| return src.as_strided(shape, strides); |
| } |
| |
| // Add dimensions of size 1 to an index tensor so that it can be broadcast to the result |
| // shape and iterated over element-wise like the result tensor and the restrided src. |
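| // Example (illustrative): an index of shape [2, 3] with dims_before = 1 and |
| // dims_after = 1 is reshaped to [1, 2, 3, 1]. |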
| static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) { |
| auto orig_shape = index.sizes(); |
| auto shape = DimVector(); |
| shape.append(dims_before, 1); |
| shape.append(orig_shape.begin(), orig_shape.end()); |
| shape.append(dims_after, 1); |
| return index.reshape(shape); |
| } |
| |
| AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) |
| { |
| int64_t element_size_bytes = src.element_size(); |
| int64_t dims_before = 0, dims_after = 0, dims_indexed = 0; |
| IntArrayRef replacement_shape; |
| for (const auto dim : c10::irange(indices_list.size())) { |
| if (!indices_list[dim].defined()) { |
| if (dims_indexed == 0) { |
| dims_before++; |
| } else { |
| dims_after++; |
| } |
| } else { |
| dims_indexed++; |
| replacement_shape = indices_list[dim].sizes(); |
| indexed_sizes.push_back(src.size(dim)); |
| indexed_strides.push_back(src.stride(dim) * element_size_bytes); |
| } |
| } |
| |
| // Check if the indexed subspace contains a dim of size 0, but the replacement |
| // shape does not. This implies that an index is out of bounds, because there |
| // is no number that's a valid index for an empty tensor. Normally, out of |
| // bounds is handled in the indexing kernel, but this case fails earlier in |
| // restride_src with an unhelpful error message. |
| if (std::find(indexed_sizes.begin(), indexed_sizes.end(), 0) != indexed_sizes.end() && |
| std::find(replacement_shape.begin(), replacement_shape.end(), 0) == replacement_shape.end()) { |
| TORCH_CHECK_INDEX(false, "index is out of bounds for dimension with size 0"); |
| } |
| |
| this->dims_before = dims_before; |
| this->dims_after = dims_after; |
| this->src = restride_src(src, dims_before, dims_indexed, replacement_shape); |
| |
| for (auto& index : indices_list) { |
| if (index.defined()) { |
| indices.push_back(reshape_indexer(index, dims_before, dims_after)); |
| } |
| } |
| |
| // For CUDA tensors, force all index tensors to have the same striding to |
| // simplify the CUDA kernel. |
| if (indices.size() >= 2 && this->src.device().type() == kCUDA) { |
| if (!all_strides_match(indices)) { |
| for (auto & indice : indices) { |
| indice = indice.contiguous(); |
| } |
| } |
| } |
| } |
| |
| static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) { |
| TORCH_CHECK(is_expandable_to(value.sizes(), info.src.sizes()), "shape mismatch: value tensor of shape ", value.sizes(), |
| " cannot be broadcast to indexing result of shape ", info.src.sizes()); |
| TORCH_CHECK(value.scalar_type() == info.src.scalar_type(), |
| "Index put requires the source and destination dtypes match, " |
| "got ", info.src.scalar_type(), " for the destination " |
| "and ", value.scalar_type(), " for the source."); |
| TensorIteratorConfig config; |
|   // info.src is restrided by restride_src, which gives the indexed dimensions stride 0 |
| config.set_check_mem_overlap(false); |
| config.resize_outputs(false); |
| config.check_all_same_dtype(false); |
| config.add_output(info.src); |
| config.add_input(value); |
| for (auto& index : info.indices) { |
| config.add_input(index); |
| } |
| return config.build(); |
| } |
| |
| static TensorIterator make_index_iterator(const AdvancedIndex& info) { |
| TensorIteratorConfig config; |
| config.set_check_mem_overlap(false) |
| .check_all_same_dtype(false) |
| .declare_static_dtype_and_device(info.src.scalar_type(), info.src.device()) |
| .add_owned_output(Tensor()) |
| .add_input(info.src); |
| for (auto& index : info.indices) { |
| config.add_input(index); |
| } |
| return config.build(); |
| } |
| |
| static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& result) { |
| TensorIteratorConfig config; |
| // info.src is a restrided view of result |
| config.set_check_mem_overlap(false) |
| .check_all_same_dtype(false) |
| .add_output(result) |
| .add_input(info.src); |
| for (auto& index : info.indices) { |
| config.add_input(index); |
| } |
| return config.build(); |
| } |
| |
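| // Illustrative usage of the index operation (assuming i and j are 1-d kLong |
| // tensors of the same length): |
| // |
| //   at::Tensor x = at::randn({4, 5, 6}); |
| //   torch::List<c10::optional<at::Tensor>> idx; |
| //   idx.push_back(c10::nullopt);  // dimension 0 is not indexed |
| //   idx.push_back(i); |
| //   idx.push_back(j); |
| //   at::Tensor y = at::index(x, idx);  // roughly x[:, i, j] in Python |
| // |
| // Since the defined index tensors are adjacent, y has shape [4, i.numel()]. |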
| Tensor index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) { |
| TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); |
| |
| auto info = make_info(self, indices); |
| auto iter = make_index_iterator(info); |
| index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides); |
| return iter.output(); |
| } |
| |
| Tensor quantized_index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) { |
| TORCH_INTERNAL_ASSERT( |
| self.qscheme() == c10::kPerTensorAffine || |
| self.qscheme() == c10::kPerTensorSymmetric, |
| "Indexing is only supported for per-Tensor quantized Tensors."); |
| |
| // For now, this is a naive implementation which does dq -> index -> q. |
| // TODO(future PR): improve performance by removing the copies. |
| const auto& self_dq = self.dequantize(); |
| |
| TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); |
| |
| auto info = make_info(self_dq, indices); |
| auto iter = make_index_iterator(info); |
| index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides); |
| at::Tensor res = iter.output(); |
| |
| return at::quantize_per_tensor( |
| res, self.q_scale(), self.q_zero_point(), self.scalar_type()); |
| } |
| |
| Tensor& index_out(Tensor& result, const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) { |
| TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); |
| at::assert_no_internal_overlap(result); |
| at::assert_no_overlap(result, self); |
| // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) |
| for (const c10::optional<Tensor>& index: indices) { |
| if (index.has_value()) { |
| at::assert_no_overlap(result, *index); |
| } |
| } |
| |
| auto info = make_info(self, indices); |
| auto iter = make_index_out_iterator(info, result); |
| index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides); |
| return result; |
| } |
| |
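| // put_ treats self as if it were a 1-d tensor. A sketch of the semantics |
| // (illustrative), writing self_flat, index_flat and source_flat for the |
| // flattened views: |
| // |
| //   for every k: self_flat[index_flat[k]] = source_flat[k] |
| // |
| // and with accumulate=true the value is added instead of overwritten. |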
| Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) { |
| // See note [Writing Nondeterministic Operations] |
| // Nondeterministic when index contains duplicate entries and we do not accumulate |
| // If we accumulate on GPU, we use atomicGPUAdd, which is non-deterministic |
| if (!accumulate || (accumulate && self.device().type() == DeviceType::CUDA)) { |
| at::globalContext().alertNotDeterministic("put_"); |
| } |
| |
| // Type and device checks |
| TORCH_CHECK(index.scalar_type() == ScalarType::Long, "put_(): Expected a long tensor for index, but got ", index.scalar_type()) |
| TORCH_CHECK(self.scalar_type() == source.scalar_type(), "put_(): self and source expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and source.dtype = ", source.scalar_type()); |
| TORCH_CHECK(self.device() == source.device() && self.device() == index.device(), |
| "put_(): self, index and source expected to be in the same device, but got self.device = ", |
| self.device(), ", index.device = ", index.device(), ", and source.device = ", source.device()); |
| |
| // index checks |
| TORCH_CHECK_INDEX(source.numel() == index.numel(), "put_(): Expected source and index to have the same number of elements, but got source.numel() = ", source.numel(), ", index.numel() = ", index.numel()); |
| TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "put_(): Tried to put elements into an empty tensor"); |
| |
| at::assert_no_internal_overlap(self); |
| at::assert_no_overlap(self, index); |
| at::assert_no_overlap(self, source); |
| |
| // Early return |
| if (index.numel() == 0) { |
| return self; |
| } |
| |
| auto index_reshaped = index.reshape(source.sizes()); |
| // Do not iterate over self, we will compute the offsets manually |
| auto iter = TensorIteratorConfig() |
| .set_check_mem_overlap(false) |
| .check_all_same_dtype(false) |
| .add_input(source) |
| .add_input(index_reshaped) |
| .build(); |
| |
| put_stub(iter.device_type(), iter, self, accumulate); |
| |
| return self; |
| } |
| |
| Tensor put(const Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) { |
| return self.clone(at::MemoryFormat::Preserve).put_(index, source, accumulate); |
| } |
| |
| Tensor index_put(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate) { |
| return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate); |
| } |
| |
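| // Illustrative semantics for a single 1-d kLong index tensor idx: |
| // |
| //   _index_put_impl_(x, {idx}, v, /*accumulate=*/false, /*unsafe=*/false)  ~  x[idx] = v |
| //   _index_put_impl_(x, {idx}, v, /*accumulate=*/true,  /*unsafe=*/false)  ~  x[idx] += v |
| // |
| // With accumulate=true, duplicate entries in idx all contribute to the sum; |
| // with accumulate=false and duplicates, the result keeps one of the written |
| // values (and the CUDA path is nondeterministic). |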
| Tensor & _index_put_impl_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate, const bool unsafe) { |
| TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); |
| if (at::has_internal_overlap(self) == MemOverlap::YES) { |
| TORCH_WARN( |
| "Use of index_put_ on expanded tensors is deprecated. " |
| "Please clone() the tensor before performing this operation. " |
| "This also applies to advanced indexing e.g. tensor[indices] = tensor"); |
| } |
| if (!accumulate) { |
| auto masked_fill_dispatch = canDispatchToMaskedFill(self, indices, value); |
| if (std::get<0>(masked_fill_dispatch)) { |
| return self.masked_fill_(std::get<1>(masked_fill_dispatch), value.item()); |
| } |
| } |
| auto value_ = value; |
| if (value.device() != self.device() && value.numel() == 1 && value.dim() == 0) { |
| value_ = value.to(self.device()); |
| } |
| at::assert_no_overlap(self, value); |
| // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) |
| for (const c10::optional<Tensor>& index: indices) { |
| if (index.has_value()) { |
| at::assert_no_overlap(self, *index); |
| } |
| } |
| if (self.device().type() == DeviceType::CUDA && (accumulate || globalContext().deterministicAlgorithms())) { |
| TORCH_CHECK(value_.device() == self.device(), "expected device ", self.device(), " but got device ", |
| value_.device(), " for value tensor"); |
| index_put_with_sort_stub(self.device().type(), self, indices, value_, accumulate, unsafe); |
| return self; |
| } |
| |
| auto info = make_info(self, indices); |
| auto iter = make_index_put_iterator(info, value_); |
| index_put_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides, accumulate); |
| return self; |
| } |
| |
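| // take reads from self as if it were a 1-d tensor; the output has the same |
| // shape as index. Illustratively, out_flat[k] = self_flat[index_flat[k]] for |
| // every k, where *_flat denotes the flattened views. |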
| Tensor& take_out(const Tensor& self, const Tensor& index, Tensor& out) { |
| // Type and device checks |
| TORCH_CHECK(index.scalar_type() == ScalarType::Long, "take(): Expected a long tensor for index, but got ", index.scalar_type()) |
| TORCH_CHECK(self.scalar_type() == out.scalar_type(), "take(): self and out expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and out.dtype = ", out.scalar_type()); |
| TORCH_CHECK(self.device() == out.device() && self.device() == index.device(), |
| "take(): self, index and out expected to be in the same device, but got self.device = ", |
| self.device(), ", index.device = ", index.device(), ", and out.device = ", out.device()); |
| |
| // index checks |
| TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "take(): tried to take from an empty tensor"); |
| |
| at::assert_no_internal_overlap(out); |
| at::assert_no_overlap(out, index); |
| at::assert_no_overlap(out, self); |
| |
| // Do not iterate over self, we will compute the offsets manually |
| // out is resized inside tensor_iterator |
| auto iter = TensorIteratorConfig() |
| .set_check_mem_overlap(false) |
| .check_all_same_dtype(false) |
| .add_output(out) |
| .add_input(index) |
| .build(); |
| |
| // Early return after out has been resized |
| if (index.numel() == 0) { |
| return out; |
| } |
| |
| take_stub(iter.device_type(), iter, self); |
| |
| return out; |
| } |
| |
| Tensor take(const Tensor& self, const Tensor& index) { |
| auto out = at::empty(index.sizes(), self.options()); |
| at::native::take_out(self, index, out); |
| return out; |
| } |
| |
| Tensor & index_put_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) { |
| return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false); |
| } |
| |
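| // index_copy_ semantics, illustratively, for a 1-d kLong index: |
| // |
| //   for each i: result.select(dim, index[i]) = source.select(dim, i) |
| // |
| // e.g. for dim == 0, result[index[i]] = source[i]. |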
| TORCH_IMPL_FUNC(index_copy_out) |
| (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Tensor& result) { |
| if (!result.is_same(self)) result.copy_(self); |
| |
| // See Note [Enabling Deterministic Operations] |
| if (result.is_cuda() && globalContext().deterministicAlgorithms()){ |
| torch::List<c10::optional<Tensor>> indices; |
| indices.reserve(dim + 1); |
| for (const auto i: c10::irange(dim)) { |
| (void)i; |
| indices.emplace_back(); |
| } |
| indices.emplace_back(index); |
| result.index_put_(indices, source, false); |
| return; |
| } |
| |
| // Handle the case when self / source is 0-dim |
| Tensor result_nonzero = result.dim() == 0 ? result.unsqueeze(0) : result; |
| Tensor source_nonzero = source.dim() == 0 ? source.unsqueeze(0) : source; |
| |
| // The only difference between the following tensor iterator and that of index_fill_ is that |
|   // this one also has source as an input. We should refactor it when if constexpr is available (C++17). |
| |
| // Prepare `index` for TensorIterator. |
| // It is restrided to be broadcastable over `self` in TensorIterator. |
| auto index_sizes = std::vector<int64_t>(result_nonzero.dim(), 1); |
| auto index_strides = std::vector<int64_t>(result_nonzero.dim(), 0); |
| index_sizes[dim] = index.numel(); |
| index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar |
| auto index_restrided = index.as_strided( |
| index_sizes, index_strides); |
| |
| // Prepare `result` for TensorIterator. |
| // Restride `result` to not advance in dimension `dim`. |
| // We do not use squash_dim here because `index` will |
| // need to advance in this dimension. |
| // Note that self_sizes[dim] is set to index.numel(). |
| // This is done so that self_sizes[dim] and index_sizes[dim] |
| // match as required by TensorIterator (input shape should |
| // strictly broadcast over output shape, i.e. |
| // output.shape[i] >= input.shape[i] for i in range(dims)). |
| auto result_sizes = result_nonzero.sizes().vec(); |
| auto result_strides = result_nonzero.strides().vec(); |
| result_sizes[dim] = index.numel(); |
| result_strides[dim] = 0; |
| auto result_restrided = result_nonzero.as_strided(result_sizes, result_strides); |
| |
| auto iter = TensorIteratorConfig() |
| // We do not check for overlap because `result` is restrided |
| // with zero stride. Zero strides trigger memory overlap assert |
| // within TensorIterator. |
| .set_check_mem_overlap(false) |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| .add_output(result_restrided) |
| .add_input(index_restrided) |
| .add_input(source_nonzero) |
| .build(); |
| |
| auto result_dim_size = result_nonzero.size(dim); |
| auto result_dim_stride = result_nonzero.stride(dim); |
| index_copy_stub( |
| iter.device_type(), |
| iter, |
| dim, |
| result_dim_size, |
| result_dim_stride); |
| } |
| |
| // Not calling into index_reduce_func_impl because of a different dtype dispatch |
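| // Semantics sketch (illustrative): starting from result = self, |
| //   for each i: result.select(dim, index[i]) += alpha * source.select(dim, i) |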
| TORCH_IMPL_FUNC(index_add_cpu_out) |
| (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) { |
| if (!result.is_same(self)) result.copy_(self); |
| auto numel = index.numel(); |
| |
| auto index_contig = index.contiguous(); |
| |
| if (result.dim() > 1) { |
| // Equivalent to: |
| // for (const auto i : c10::irange(numel)) { |
| // auto selfSlice = self.select(dim, index_data[i]); |
| // auto sourceSlice = source.select(dim, i); |
| // selfSlice.add_(sourceSlice); |
| // } |
| // But much faster as this reuses the iterator from add_ |
| if (numel == 0) { |
| return; |
| } |
| auto selfSlice = result.select(dim, 0); |
| auto sourceSlice = source.select(dim, 0); |
| auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); |
| auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type()); |
| auto self_dim_size = result.size(dim); |
| auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice); |
| |
| AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(numel)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); |
| auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes; |
| auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes; |
| iter.unsafe_replace_operand(0, self_data); |
| iter.unsafe_replace_operand(1, self_data); |
| iter.unsafe_replace_operand(2, source_data); |
| add_stub(iter.device_type(), iter, alpha); |
| } |
| }); |
| } |
| else { |
| TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); |
| |
| // explicitly capture all required variables to work around windows build |
| // TODO: fix this when windows can correctly capture variables in nested lambda |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, |
| result.scalar_type(), "index_add_", [&result, &source, &dim, &index_contig, &numel, &alpha] { |
| auto alpha_value = alpha.to<scalar_t>(); |
| auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); |
| auto source_stride = source.dim() == 0 ? 1 : source.stride(dim); |
| // TODO: Maybe TensorAccessor can be used here? |
| auto* result_ptr = result.data_ptr<scalar_t>(); |
| auto* source_ptr = source.data_ptr<scalar_t>(); |
| AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_", |
| [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &alpha_value] { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(numel)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self"); |
| scalar_t *self_ip = result_ptr + self_i * result_stride; |
| *self_ip += *(source_ptr + i * source_stride) * alpha_value; |
| } |
| }); |
| }); |
| } |
| } |
| |
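| // Reduces source into result along dim at the positions given by index: |
| // illustratively, for each i, |
| // |
| //   result.select(dim, index[i]) = op(result.select(dim, index[i]), |
| //                                     source.select(dim, i)) |
| // |
| // When include_self is false, the indexed slices of result are first |
| // reinitialized to the identity of op (1 for prod, -inf/lowest for amax, |
| // +inf/max for amin, 0 for mean) so that only source contributes. |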
| void index_reduce_func_impl( |
| const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& source, |
| bool include_self, |
| const Tensor& result, |
| const SCATTER_GATHER_OP& op) { |
| if (!result.is_same(self)) result.copy_(self); |
| if (!include_self) { |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, at::ScalarType::BFloat16, |
| self.scalar_type(), "index_reduce_func_exclude_input_init", [&] { |
| scalar_t init_val; |
| switch (op) { |
| case SCATTER_GATHER_OP::REDUCE_MULTIPLY: |
| init_val = (scalar_t)1; |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MAXIMUM: |
| init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity() |
| : std::numeric_limits<scalar_t>::lowest(); |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MINIMUM: |
| init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity() |
| : std::numeric_limits<scalar_t>::max(); |
| break; |
| default: |
| init_val = (scalar_t)0; |
| break; |
| } |
| // index_fill_ requires index to be a LongTensor |
| result.index_fill_(dim, index.to(at::ScalarType::Long), init_val); |
| }); |
| } |
| |
| auto numel = index.numel(); |
| |
| auto index_contig = index.contiguous(); |
| |
| if (result.dim() > 1) { |
| // Equivalent to: |
| // for (const auto i : c10::irange(numel)) { |
| // auto selfSlice = self.select(dim, index_data[i]); |
| // auto sourceSlice = source.select(dim, i); |
| // selfSlice.op_(sourceSlice); |
| // } |
| // But much faster as this reuses the iterator from the binary op |
| if (numel == 0) { |
| return; |
| } |
| auto selfSlice = result.select(dim, 0); |
| auto sourceSlice = source.select(dim, 0); |
| auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); |
| auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type()); |
| auto self_dim_size = result.size(dim); |
| auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice); |
| |
| AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_func_cpu_", [&] () { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(numel)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); |
| auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes; |
| auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes; |
| iter.unsafe_replace_operand(0, self_data); |
| iter.unsafe_replace_operand(1, self_data); |
| iter.unsafe_replace_operand(2, source_data); |
| |
| switch (op) { |
| case SCATTER_GATHER_OP::REDUCE_MULTIPLY : |
| mul_stub(iter.device_type(), iter); |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MINIMUM : |
| minimum_stub(iter.device_type(), iter); |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MAXIMUM : |
| maximum_stub(iter.device_type(), iter); |
| break; |
| default : |
| add_stub(iter.device_type(), iter, 1); |
| break; |
| } |
| } |
| }); |
| |
| if (op == SCATTER_GATHER_OP::REDUCE_MEAN) { |
| auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); |
| counts.index_add_(dim, index, at::ones_like(source)); |
| counts.masked_fill_(counts == 0, 1); |
| result.div_(counts); |
| } |
| } |
| else { |
| TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); |
| auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); |
| // explicitly capture all required variables to work around windows build |
| // TODO: fix this when windows can correctly capture variables in nested lambda |
| AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, |
| result.scalar_type(), "index_func_", [&result, &source, &dim, &index_contig, &numel, &op, &counts] { |
| auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); |
| auto source_stride = source.dim() == 0 ? 1 : source.stride(dim); |
| auto counts_stride = counts.dim() == 0 ? 1 : counts.stride(dim); |
| // TODO: Maybe TensorAccessor can be used here? |
| auto* result_ptr = result.data_ptr<scalar_t>(); |
| auto* source_ptr = source.data_ptr<scalar_t>(); |
| auto counts_ptr = counts.data_ptr<scalar_t>(); |
| AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_func_cpu_", |
| [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &op, &counts_ptr, &counts_stride] { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(numel)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self"); |
| scalar_t *self_ip = result_ptr + self_i * result_stride; |
| scalar_t *count_ip; |
| scalar_t val; |
| switch (op) { |
| case SCATTER_GATHER_OP::REDUCE_MEAN : |
| *self_ip += *(source_ptr + i * source_stride); |
| count_ip = counts_ptr + self_i * counts_stride; |
| *count_ip += 1; |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MULTIPLY : |
| *self_ip *= *(source_ptr + i * source_stride); |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MINIMUM : |
| val = *(source_ptr + i * source_stride); |
| *self_ip = at::_isnan<scalar_t>(val) ? val : std::min(*self_ip, val); |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MAXIMUM : |
| val = *(source_ptr + i * source_stride); |
| *self_ip = at::_isnan<scalar_t>(val) ? val : std::max(*self_ip, val); |
| break; |
| default: |
| break; |
| } |
| } |
| }); |
| }); |
| if (op == SCATTER_GATHER_OP::REDUCE_MEAN) { |
| counts.masked_fill_(counts == 0, 1); |
| result.div_(counts); |
| } |
| } |
| } |
| |
| TORCH_IMPL_FUNC(index_reduce_cpu_out) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& source, |
| const c10::string_view reduce, |
| bool include_input, |
| const Tensor& result) { |
| TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time."); |
| auto op = meta::get_operator_enum(reduce, true); |
| index_reduce_func_impl(self, dim, index, source, include_input, result, op); |
| } |
| |
| // Check that indices fall within dimension array size |
| // Avoid redispatch call to min/max |
| template <typename IndexType> |
| static void check_indexarray_range( |
| const IndexType* indices, |
| int64_t n, |
| IndexType indexing_axis_dim) { |
| for (const auto i : c10::irange(n)) { |
| auto idx = indices[i]; |
| TORCH_CHECK( |
| 0 <= idx && idx < indexing_axis_dim, |
| "INDICES element is out of DATA bounds, id=", |
| idx, |
| " axis_dim=", |
| indexing_axis_dim); |
| } |
| } |
| |
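| // Fast path used by index_select_out_cpu_ when dim == 1 and the result is |
| // contiguous. Viewing the (contiguous) self as [outer, self.size(1), block] |
| // and result_contig as [outer, N, block] with N = index.numel(), each |
| // selected block is copied in one go: |
| // |
| //   result_contig[batch][i][:] = self[batch][index[i]][:] |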
| Tensor & index_select_out_cpu_dim1_( |
| Tensor & result_contig, const Tensor & self, const Tensor & index_contig) { |
| |
| auto self_contig = self.contiguous(); |
| const caffe2::TypeMeta dataType = self_contig.dtype(); |
| size_t item_bytesize = dataType.itemsize(); |
| |
| auto out = static_cast<char*>(result_contig.data_ptr()); |
| |
| auto src_base = static_cast<const char*>(self_contig.data_ptr()); |
| |
| auto self_sizes = self_contig.sizes(); |
| auto outer_dims_product = c10::size_to_dim_(1, self_sizes); |
| auto block_size = c10::size_from_dim_(2, self_sizes); |
| auto block_bytesize = block_size * item_bytesize; |
| |
| auto src_indexing_axis_dim = self_sizes[1]; |
| auto src_batch_bytesize = self_sizes[1] * block_bytesize; |
| auto N = index_contig.numel(); |
| |
| auto gathered_batch_bytesize = N * block_bytesize; |
| |
| AT_DISPATCH_INDEX_TYPES( |
| index_contig.scalar_type(), "batch_index_select_compute", [&]() { |
| |
| const auto* idxs = index_contig.data_ptr<index_t>(); |
| check_indexarray_range<index_t>(idxs, N, src_indexing_axis_dim); |
| |
| // Special-case single-float copy for efficiency |
| if (self.scalar_type() == ScalarType::Float && block_size == 1) { |
| for (const auto batch : c10::irange(outer_dims_product)) { |
| const float* src_floats = |
| (const float*)(src_base + batch * src_batch_bytesize); |
| float* dst_floats = (float*)(out + batch * gathered_batch_bytesize); |
| |
| for (const auto i : c10::irange(N)) { |
| auto idx = idxs[i]; |
| dst_floats[i] = src_floats[idx]; |
| } |
| } |
| } else { |
| // outer_dims_product specifies how many times we repeat inner dimensions, |
| // so we just iterate over it to cover all outer dimensions. |
| for (const auto batch : c10::irange(outer_dims_product)) { |
| for (const auto i : c10::irange(N)) { |
| auto idx = idxs[i]; |
| auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize; |
| auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize; |
| memcpy(dst, src, block_bytesize); |
| } |
| } |
| } |
| }); |
| return result_contig; |
| } |
| |
| Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor & index, Tensor & result) { |
| if (self.is_quantized()) { |
| TORCH_CHECK( |
| self.qscheme() == kPerTensorAffine, |
| "Only per_tensor quantized quantized tensors are supported by index_select.") |
| } |
| dim = maybe_wrap_dim(dim, self.dim()); |
| auto numel = index.numel(); |
| TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector"); |
| TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index"); |
| TORCH_CHECK(self.scalar_type() == result.scalar_type(), |
| "index_select(): self and result must have the same scalar type"); |
| TORCH_CHECK(dim == 0 || dim < self.dim(), |
| "index_select(): Indexing dim ", dim, " is out of bounds of tensor"); |
| at::assert_no_internal_overlap(result); |
| at::assert_no_overlap(result, self); |
| at::assert_no_overlap(result, index); |
| auto result_size = self.sizes().vec(); |
| if (self.dim() > 0) { |
| result_size[dim] = numel; |
| } |
| at::native::resize_output(result, result_size); |
| |
| auto index_contig = index.contiguous(); |
| |
| if (self.dim() > 1) { |
| if (numel == 0) { |
| return result; |
| } |
| if (self.numel() == 0) { |
| auto src_indexing_axis_dim = self.size(dim); |
| TORCH_CHECK(src_indexing_axis_dim > 0, |
| "index_select(): self indexing axis dim should be positive"); |
| AT_DISPATCH_INDEX_TYPES( |
| index_contig.scalar_type(), "index_select_empty_self_bound_check", [&]() { |
| const auto* idxs = index_contig.data_ptr<index_t>(); |
| check_indexarray_range<index_t>(idxs, numel, src_indexing_axis_dim); |
| }); |
| return result; |
| } |
| |
| if (dim == 1 && result.is_contiguous()) { |
| // fast pass |
| return index_select_out_cpu_dim1_(result, self, index_contig); |
| } |
| |
| auto selfSlice = self.select(dim, 0); |
| auto resultSlice = result.select(dim, 0); |
| auto selfSlice_data = selfSlice.data_ptr(); |
| auto resultSlice_data = resultSlice.data_ptr(); |
| auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type()); |
| auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); |
| auto self_dim_size = self.size(dim); |
| auto slice_size = selfSlice.numel(); |
| |
| auto iter = TensorIteratorConfig() |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| .add_output(resultSlice) |
| .add_input(selfSlice) |
| .build(); |
| |
| auto grain_size = at::internal::GRAIN_SIZE; |
| auto outer_loop = |
| // explicitly capture all required variables to work around windows build |
| // TODO: fix this when windows can correctly capture variables in nested lambda |
| [&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data, |
| &result_stride_bytes](int64_t start, int64_t end) { |
| auto sub_iter = TensorIterator(iter); |
| AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", |
| [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, |
| &resultSlice_data, &result_stride_bytes] () { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(start, end)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); |
| auto self_data = static_cast<char*>(selfSlice_data) + self_i * self_stride_bytes; |
| auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes; |
| sub_iter.unsafe_replace_operand(0, result_data); |
| sub_iter.unsafe_replace_operand(1, self_data); |
| copy_stub(sub_iter.device_type(), sub_iter, false); |
| }; |
| }); |
| }; |
| |
| // parallel on inner loop in case the slice is large enough; |
| // otherwise parallel on outer loop |
| if (slice_size >= grain_size) { |
| outer_loop(0, numel); |
| } else { |
| // use a fast loop when self and result are contiguous and of the same data type |
| if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) { |
| auto slice_size_bytes = slice_size * elementSize(self.scalar_type()); |
| // explicitly capture all required variables to work around windows build |
| // TODO: fix this when windows can correctly capture variables in nested lambda |
| at::parallel_for(0, numel, grain_size / slice_size, |
| [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, |
| &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) { |
| AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", |
| [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, |
| &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(start, end)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); |
| auto self_data = static_cast<char*>(selfSlice_data) + self_i * self_stride_bytes; |
| auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes; |
| memcpy(result_data, self_data, slice_size_bytes); |
| } |
| }); |
| }); |
| } else { |
| at::parallel_for(0, numel, grain_size / slice_size, outer_loop); |
| } |
| } |
| } else { |
| TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); |
| // explicitly capture all required variables to work around windows build |
| // TODO: fix this when windows can correctly capture variables in nested lambda |
| if(self.is_quantized()){ |
| AT_DISPATCH_QINT_TYPES(self.scalar_type(), "index_select_quant", [&index_contig, &self, &result, &dim, &numel] { |
| auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); |
| auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); |
| auto self_data_ptr = self.data_ptr<scalar_t>(); |
| auto result_data_ptr = result.data_ptr<scalar_t>(); |
| auto self_numel = self.numel(); |
| AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_quant_", |
| [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(numel)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); |
| scalar_t *self_ip = self_data_ptr + self_i * self_stride; |
| *(result_data_ptr + i * result_stride) = *self_ip; |
| } |
| }); |
| }); |
| } else { |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, |
| self.scalar_type(), "index_select", [&index_contig, &self, &result, &dim, &numel] { |
| auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); |
| auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); |
| |
| auto self_data_ptr = self.data_ptr<scalar_t>(); |
| auto result_data_ptr = result.data_ptr<scalar_t>(); |
| auto self_numel = self.numel(); |
| AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", |
| [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] { |
| auto index_data = index_contig.data_ptr<index_t>(); |
| for (const auto i : c10::irange(numel)) { |
| auto self_i = index_data[i]; |
| TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); |
| scalar_t *self_ip = self_data_ptr + self_i * self_stride; |
| *(result_data_ptr + i * result_stride) = *self_ip; |
| } |
| }); |
| }); |
| } |
| } |
| |
| return result; |
| } |
| |
| Tensor index_select_cpu_(const Tensor & self, int64_t dim, const Tensor & index) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::native::index_select_out_cpu_(self, dim, index, result); |
| } |
| |
| Tensor index_select_quantized_cpu_(const Tensor & self, int64_t dim, const Tensor & index) { |
| TORCH_CHECK(self.qscheme() == kPerTensorAffine, |
| "Only per_tensor quantized quantized tensors are supported by index_select.") |
| Tensor result = at::empty_quantized({0}, self); |
| return at::native::index_select_out_cpu_(self, dim, index, result); |
| } |
| |
| Tensor index_select_backward(const Tensor& grad, IntArrayRef self_sizes, int64_t dim, const Tensor& index) { |
| // for composite compliance, use out-of-place variant of |
| // `index_add` if index tensor is a Tensor Subclass. |
| if (isTensorSubclassLike(index)) { |
| return grad.new_zeros(self_sizes, grad.options()).index_add(dim, index, grad); |
| } |
| return grad.new_zeros(self_sizes, grad.options()).index_add_(dim, index, grad); |
| } |
| |
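| // index_fill_ writes the scalar source into every slice of self selected by |
| // index along dim; e.g. for dim == 0 (illustrative): self[index[i]] = source |
| // for each i. |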
| Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) { |
| at::NoNamesGuard guard; |
| |
| TORCH_CHECK_INDEX( |
| index.scalar_type() == ScalarType::Long, |
| "index_fill_(): Expected dtype int64 for index."); |
| |
| at::assert_no_overlap(self, index); |
| if (at::has_internal_overlap(self) == at::MemOverlap::YES) { |
| TORCH_WARN( |
| "Use of index_fill_ on expanded tensors is deprecated. " |
| "Please clone() the tensor before performing this operation. " |
| "This also applies to advanced indexing e.g. tensor[mask] = scalar"); |
| } |
| |
| if (!self.is_complex() && source.isComplex()) { |
| TORCH_CHECK(false, "index_fill_(): Converting complex Scalar to non-complex type is not supported"); |
| } |
| |
| // Handle the case when `self` is 0-dim |
| Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self; |
| |
| dim = at::maybe_wrap_dim(dim, self_nonzero_dim); |
| TORCH_CHECK(index.dim() <= 1, "Index has to be a vector/scalar"); |
| |
| // Prepare `index` for TensorIterator. |
| // It is restrided to be broadcastable over `self` in TensorIterator. |
| auto index_sizes = std::vector<int64_t>(self_nonzero_dim.dim(), 1); |
| auto index_strides = std::vector<int64_t>(self_nonzero_dim.dim(), 0); |
| index_sizes[dim] = index.numel(); |
| index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar |
| auto index_restrided = index.as_strided( |
| index_sizes, index_strides); |
| |
| // Prepare `self` for TensorIterator. |
| // Restride `self` to not advance in dimension `dim`. |
| // We do not use squash_dim here because `index` will |
| // need to advance in this dimension. |
| // Note that self_sizes[dim] is set to index.numel(). |
| // This is done so that self_sizes[dim] and index_sizes[dim] |
| // match as required by TensorIterator (input shape should |
| // strictly broadcast over output shape, i.e. |
| // output.shape[i] >= input.shape[i] for i in range(dims)). |
| auto self_sizes = self_nonzero_dim.sizes().vec(); |
| auto self_strides = self_nonzero_dim.strides().vec(); |
| self_sizes[dim] = index.numel(); |
| self_strides[dim] = 0; |
| auto self_restrided = self_nonzero_dim.as_strided(self_sizes, self_strides); |
| |
| auto iter = TensorIteratorConfig() |
| // We do not check for overlap because `self` is restrided |
| // with zero stride. Zero strides trigger memory overlap assert |
| // within TensorIterator. |
| .set_check_mem_overlap(false) |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| .add_output(self_restrided) |
| .add_input(index_restrided) |
| .build(); |
| |
| auto self_dim_size = (self_nonzero_dim.sizes())[dim]; |
| auto self_dim_stride = (self_nonzero_dim.strides())[dim]; |
| index_fill_stub( |
| iter.device_type(), |
| iter, |
| dim, |
| self_dim_size, |
| self_dim_stride, |
| source); |
| |
| return self; |
| } |
| |
| Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { |
| TORCH_CHECK(source.dim() == 0, "index_fill_ only supports a 0-dimensional value tensor, but got tensor " |
| "with ", source.dim(), " dimension(s)."); |
| return self.index_fill_(dim, index, source.item()); |
| } |
| |
| Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) { |
| return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source); |
| } |
| |
| Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { |
| return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source); |
| } |
| |
| // gather_out_cpu_cuda |
| TORCH_IMPL_FUNC(gather_out) |
| (const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& result) { |
| if (index.numel() == 0) return; |
| dim = at::maybe_wrap_dim(dim, self.dim()); |
| gather_stub(result.device().type(), result, self, dim, index); |
| } |
| |
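| // The gradient of gather scatters grad back to the gathered locations: |
| // illustratively, grad_self = zeros_like(self).scatter_add_(dim, index, grad), |
| // so duplicate indices accumulate their gradient contributions. |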
| Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) { |
| if (sparse_grad) { |
| return at::_gather_sparse_backward(self, dim, index, grad); |
| } |
| auto result = grad.new_zeros(self.sizes()); |
| // for composite compliance, use out-of-place variant of |
| // `scatter_add` if index tensor is a Tensor Subclass. |
| if (isTensorSubclassLike(index)) { |
| return result.scatter_add(dim, index, grad); |
| } |
| result.scatter_add_(dim, index, grad); |
| return result; |
| } |
| |
| static void scatter_reduce_exclude_self_helper( |
| const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const SCATTER_GATHER_OP& op) { |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( |
| at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, |
| self.scalar_type(), "scatter_reduce_exclude_input_init", [&] { |
| scalar_t init_val; |
| switch (op) { |
| case SCATTER_GATHER_OP::REDUCE_ADD: |
| init_val = (scalar_t)0; |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MULTIPLY: |
| init_val = (scalar_t)1; |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MAXIMUM: |
| init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity() |
| : std::numeric_limits<scalar_t>::lowest(); |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MINIMUM: |
| init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity() |
| : std::numeric_limits<scalar_t>::max(); |
| break; |
| case SCATTER_GATHER_OP::REDUCE_MEAN: |
| init_val = (scalar_t)0; |
| break; |
| } |
| self.scatter_(dim, index, init_val); |
| }); |
| } |
| |
| template <bool use_new_options = false, typename T, typename ReduceStub, typename FillStub> |
| void scatter_impl( |
| const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const T& src, |
| const Tensor& out, |
| ReduceStub& reduce_stub, |
| FillStub& fill_stub, |
| const c10::optional<c10::string_view> reduce = nullopt, |
| bool reduce_includes_self = true) { |
| |
| dim = at::maybe_wrap_dim(dim, self.dim()); |
| auto mut_out = const_cast<Tensor&>(out); |
| |
| if (!self.is_same(mut_out)) { |
| mut_out.copy_(self); |
| } |
| |
| if (index.numel() == 0) return; |
| |
| if (reduce.has_value()) { |
| auto op = meta::get_operator_enum(reduce.value(), use_new_options); |
| if (!reduce_includes_self) { |
| // initialize the scattered-to slots with the reduction identity (used by scatter_reduce.two) |
| scatter_reduce_exclude_self_helper(mut_out, dim, index, op); |
| } |
| reduce_stub(self.device().type(), mut_out, dim, index, src, op); |
| } else { |
| fill_stub(self.device().type(), mut_out, dim, index, src); |
| } |
| } |
| |
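| // Scatter semantics (as documented for torch.scatter_), e.g. for dim == 0: |
| // |
| // out[index[i][j]][j] = src[i][j] |
| // |
| // With a reduce argument the assignment becomes the reduction, e.g. |
| // out[index[i][j]][j] += src[i][j] for an additive reduce. |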
| TORCH_IMPL_FUNC(scatter_src_out) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& src, |
| const Tensor& out) { |
| scatter_impl(self, dim, index, src, out, |
| scatter_reduce_stub, |
| scatter_stub); |
| } |
| |
| TORCH_IMPL_FUNC(scatter_value_out) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Scalar& value, |
| const Tensor& out) { |
| scatter_impl(self, dim, index, value, out, |
| scatter_scalar_reduce_stub, |
| scatter_fill_stub); |
| } |
| |
| TORCH_IMPL_FUNC(scatter_reduce_out) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& src, |
| const c10::string_view reduce, |
| const Tensor& out) { |
| scatter_impl(self, dim, index, src, out, |
| scatter_reduce_stub, |
| scatter_stub, |
| reduce); |
| } |
| |
| TORCH_IMPL_FUNC(scatter_value_reduce_out) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Scalar& value, |
| const c10::string_view reduce, |
| const Tensor& out) { |
| scatter_impl(self, dim, index, value, out, |
| scatter_scalar_reduce_stub, |
| scatter_fill_stub, |
| reduce); |
| } |
| |
| TORCH_IMPL_FUNC(scatter_add) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& src, |
| const Tensor& out) { |
| auto mut_out = const_cast<Tensor&>(out); |
| dim = maybe_wrap_dim(dim, self.dim()); |
| |
| if (!self.is_same(mut_out)) { |
| mut_out.copy_(self); |
| } |
| |
| if (index.numel() == 0) return; |
| |
| if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA && self.dim() == 1) { |
| TORCH_CHECK(index.dim() == 1 && src.dim() == 1, "index and src should be 1D tensors when self is a 1D tensor, " |
| "but their dims are ", index.dim(), " and ", src.dim(), ", respectively"); |
| TORCH_CHECK(index.numel() == src.numel(), "index and src should have same number of elements for 1D tensors, " |
| "but got ", index.numel(), " versus ", src.numel()); |
| TORCH_CHECK(dim == 0, "dim should be zero for 1D self tensor, but got ", dim); |
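| // On this deterministic 1-D path, scatter_add_ is expressed via index_put_ |
| // with accumulate=true, which is equivalent here: out[index[i]] += src[i]. |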
| torch::List<c10::optional<Tensor>> indices; |
| indices.reserve(1); |
| indices.push_back(index); |
| mut_out.index_put_(indices, src, true); |
| } else { |
| scatter_add_stub(self.device().type(), mut_out, dim, index, src); |
| } |
| } |
| |
| TORCH_IMPL_FUNC(scatter_reduce_two) |
| (const Tensor& self, |
| int64_t dim, |
| const Tensor& index, |
| const Tensor& src, |
| const c10::string_view reduce, |
| bool include_self, |
| const Tensor& out) { |
| // See issue https://github.com/pytorch/pytorch/issues/74770 |
| TORCH_WARN_ONCE("scatter_reduce() is in beta and the API may change at any time."); |
| |
| scatter_impl</*use_new_options=*/true>(self, dim, index, src, out, |
| scatter_reduce_two_stub, |
| scatter_stub, |
| reduce, |
| include_self); |
| |
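| // For reduce="mean", the stub above accumulates a sum; the mean is obtained |
| // by dividing each slot by its number of contributions. Illustrative |
| // example: if two src elements and (with include_self=true) the original |
| // self value land in the same slot, its count is 3 and the summed value is |
| // divided by 3; untouched slots keep count 1 and are left unchanged. |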
| if (meta::get_operator_enum(reduce, true) == SCATTER_GATHER_OP::REDUCE_MEAN) { |
| auto ones = at::ones_like(src); |
| auto count = include_self ? at::ones_like(out) : at::zeros_like(out); |
| count.scatter_add_(dim, index, ones); |
| count.masked_fill_(count == 0, 1); |
| |
| if (out.is_floating_point() || out.is_complex()) { |
| out.div_(count); |
| } else { |
| out.div_(count, "floor"); |
| } |
| } |
| } |
| |
| Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) { |
| c10::MaybeOwned<Tensor> _mask, _self; |
| std::tie(_mask, _self) = expand_outplace(mask, self); |
| return _self->clone(at::MemoryFormat::Contiguous).masked_scatter_(*_mask, source); |
| } |
| |
| static Tensor & masked_fill_impl_cpu(Tensor & self, const Tensor & mask, const Scalar& value) { |
| NoNamesGuard guard; |
| if (mask.dtype() == ScalarType::Byte) { |
| TORCH_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \ |
| "please use a mask with dtype torch.bool instead."); |
| } |
| |
| if (at::has_internal_overlap(self) == MemOverlap::YES) { |
| TORCH_WARN( |
| "Use of masked_fill_ on expanded tensors is deprecated. " |
| "Please clone() the tensor before performing this operation. " |
| "This also applies to advanced indexing e.g. tensor[mask] = scalar"); |
| } |
| at::assert_no_partial_overlap(self, mask); |
| |
| auto iter = TensorIteratorConfig() |
| .set_check_mem_overlap(false) // deprecated, but not a hard error |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| .add_output(self) |
| .add_input(mask) |
| .build(); |
| |
| masked_fill_stub(iter.device_type(), iter, value); |
| return self; |
| } |
| |
| Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Scalar& value) { |
| auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_"); |
| |
| masked_fill_impl_cpu(self, mask, value); |
| namedinference::propagate_names_if_nonempty(self, maybe_outnames); |
| return self; |
| } |
| |
| Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Tensor & value) { |
| auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_"); |
| TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor " |
| "with ", value.dim(), " dimension(s)."); |
| |
| masked_fill_impl_cpu(self, mask, value.item()); |
| namedinference::propagate_names_if_nonempty(self, maybe_outnames); |
| return self; |
| } |
| |
| Tensor masked_fill(const Tensor & self, const Tensor & mask, const Scalar& source) { |
| Tensor result; |
| auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill"); |
| { |
| NoNamesGuard guard; |
| c10::MaybeOwned<Tensor> _mask, _self; |
| std::tie(_mask, _self) = expand_outplace(mask, self); |
| result = _self->clone(at::MemoryFormat::Contiguous); |
| result.masked_fill_(mask, source); |
| } |
| namedinference::propagate_names_if_nonempty(result, maybe_outnames); |
| return result; |
| } |
| |
| Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) { |
| Tensor result; |
| auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill"); |
| { |
| NoNamesGuard guard; |
| c10::MaybeOwned<Tensor> _mask, _self; |
| std::tie(_mask, _self) = expand_outplace(mask, self); |
| result = _self->clone(at::MemoryFormat::Contiguous); |
| result.masked_fill_(mask, source); |
| } |
| namedinference::propagate_names_if_nonempty(result, maybe_outnames); |
| return result; |
| } |
| |
| static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) { |
| NoNamesGuard guard; |
| |
| TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool, |
| "masked_select: expected BoolTensor or ByteTensor for mask"); |
| TORCH_CHECK(self.scalar_type() == result.scalar_type(), |
| "masked_select(): self and result must have the same scalar type"); |
| |
| at::assert_no_internal_overlap(result); |
| at::assert_no_overlap(result, self); |
| at::assert_no_overlap(result, mask); |
| |
| if (mask.dtype() == at::ScalarType::Byte) { |
| TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \ |
| "please use a mask with dtype torch.bool instead."); |
| } |
| |
| c10::MaybeOwned<Tensor> _mask, _self; |
| std::tie(_mask, _self) = expand_outplace(mask, self); |
| |
| auto shape = _self->sizes(); |
| int64_t numel = _mask->sum().item().toLong(); |
| at::native::resize_output(result, {numel}); |
| if (numel == 0) { |
| return result; |
| } |
| |
| // Create strided view of result before feeding into TensorIterator |
| auto strides = DimVector(shape.size(), 0); |
| auto orig_stride = result.strides()[0]; |
| auto result_strided = result.as_strided(shape, strides); |
| |
| // Serial kernel |
| // The serial kernel requires that `src` is traversed in its logical order. However, TensorIterator |
| // might have reordered dimensions so that `src` would be traversed in its physical order, producing |
| // wrong answers. A sufficient condition for no reordering to have happened is that both `_self` and |
| // `_mask` are contiguous. If that condition is not satisfied, use the parallel kernel, which handles |
| // permutations correctly. |
| bool use_serial_kernel = (self.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ) && |
| _self->is_contiguous() && _mask->is_contiguous(); |
| if (use_serial_kernel) { |
| auto iter = TensorIteratorConfig() |
| .set_check_mem_overlap(false) // result is intentionally zero-strided above |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| .add_output(result_strided) |
| .add_input(*_self) |
| .add_input(*_mask) |
| .build(); |
| |
| masked_select_serial_stub(iter.device_type(), iter, orig_stride); |
| return result; |
| } |
| |
| // Use a prefix sum to record the output locations of the masked elements, |
| // so that the copy can be parallelized with TensorIterator. |
| auto mask_long = at::empty(shape, self.options().dtype(at::kLong)).copy_(*_mask); |
| auto mask_prefix_sum = at::empty(shape, self.options().dtype(at::kLong)); |
| auto mask_long_data = mask_long.data_ptr<int64_t>(); |
| auto mask_prefix_sum_data = mask_prefix_sum.data_ptr<int64_t>(); |
| // TODO: Under C++14 we can only use std::partial_sum here; switch to |
| // std::exclusive_scan, which has better performance, once PyTorch upgrades to C++17. |
| // std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0); |
| std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data); |
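| // Illustrative example: mask [0, 1, 1, 0, 1] yields the inclusive prefix |
| // sum [0, 1, 2, 2, 3], i.e. each selected element sees its 1-based position |
| // in the output, which the parallel kernel uses to compute its write offset. |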
| |
| auto iter = TensorIteratorConfig() |
| .set_check_mem_overlap(false) // result is intentionally zero-strided above |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| .add_output(result_strided) |
| .add_input(*_self) |
| .add_input(*_mask) |
| .add_input(mask_prefix_sum) |
| .build(); |
| |
| masked_select_stub(iter.device_type(), iter, orig_stride); |
| return result; |
| } |
| |
| Tensor & masked_select_out_cpu(const Tensor & self, const Tensor & mask, Tensor & result) { |
| namedinference::compute_broadcast_outnames(self, mask); |
| return masked_select_out_impl_cpu(result, self, mask); |
| } |
| |
| Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::native::masked_select_out_cpu(self, mask, result); |
| } |
| |
| Tensor masked_select_backward(const Tensor& grad, const Tensor& input, const Tensor& mask) { |
| // The following could just be written as `zeros_like(input).masked_scatter(mask, grad)`. |
| // However, as an optimization, we call the in-place variant of masked_scatter. |
| // Unfortunately, that doesn't allow for the broadcasting of the LHS, so we need |
| // to explicitly broadcast here (the out-of-place variant of masked_scatter |
| // implicitly handles broadcasting). |
| auto result = at::zeros_like( |
| input.expand(at::infer_size(input.sizes(), mask.sizes())), at::MemoryFormat::Preserve); |
| |
| // for composite compliance, use out-of-place variant |
| // of `masked_scatter`. |
| if (areAnyTensorSubclassLike({grad, mask})) { |
| return result.masked_scatter(mask, grad); |
| } |
| result.masked_scatter_(mask, grad); |
| return result; |
| } |
| |
| namespace { |
| |
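| // Broadcasts `self` and `indices` against each other everywhere except at |
| // `dim`, so that a plain gather along `dim` implements take_along_dim. |
| // Illustrative example: self (3, 4), indices (1, 2), dim == 1 gives |
| // self_broadcasted (3, 4) and indices_broadcasted (3, 2); the subsequent |
| // gather(1, indices_broadcasted) then has shape (3, 2). |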
| inline std::tuple<Tensor, Tensor, int64_t> _take_along_dim_helper( |
| const Tensor& self, |
| const Tensor& indices, |
| int64_t dim) { |
| TORCH_CHECK( |
| self.dim() == indices.dim(), |
| "torch.take_along_dim(): input and indices should have the same number of dimensions, ", |
| "but got ", self.dim(), " dimensions for input, and ", indices.dim(), " dimensions for indices") |
| TORCH_CHECK( |
| indices.scalar_type() == ScalarType::Long, |
| "torch.take_along_dim(): dtype of indices should be Long but got ", indices.scalar_type()) |
| |
| dim = at::maybe_wrap_dim(dim, self.dim()); |
| |
| DimVector self_sizes{self.sizes()}; |
| // update number of elements at dim as per indices |
| self_sizes[dim] = indices.size(dim); |
| auto broadcast_shape = infer_size(self_sizes, indices.sizes()); |
| auto indices_broadcasted = at::broadcast_to(indices, broadcast_shape); |
| |
| DimVector indices_sizes{indices.sizes()}; |
| // update number of elements at dim as per self |
| indices_sizes[dim] = self.size(dim); |
| broadcast_shape = infer_size(indices_sizes, self.sizes()); |
| auto self_broadcasted = at::broadcast_to(self, broadcast_shape); |
| |
| return std::make_tuple(self_broadcasted, indices_broadcasted, dim); |
| } |
| |
| static inline void checkDevice(CheckedFrom c, const Tensor& t, Device device) { |
| TORCH_CHECK( |
| !t.defined() || t.device() == device, |
| "Expected tensor to have ", device, |
| " Device, but got tensor with ", t.device(), " Device ", |
| "(while checking arguments for ", c, ")"); |
| } |
| |
| static inline void checkDevice(CheckedFrom c, at::ArrayRef<Tensor> tensors, Device device) { |
| for (auto &t : tensors) { |
| checkDevice(c, t, device); |
| } |
| } |
| |
| } // anonymous namespace |
| |
| Tensor take_along_dim(const Tensor& self, const Tensor& indices, c10::optional<int64_t> opt_dim) { |
| checkDevice("torch.take_along_dim():", {self, indices}, self.device()); |
| if (opt_dim.has_value()) { |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int64_t dim; |
| Tensor self_broadcasted, indices_broadcasted; |
| std::tie(self_broadcasted, indices_broadcasted, dim) = |
| _take_along_dim_helper(self, indices, opt_dim.value()); |
| return self_broadcasted.gather(dim, indices_broadcasted); |
| } |
| |
| // similar to `take`, but `take` doesn't support the same dtypes as `gather`. |
| return self.view(-1).gather(0, indices.view(-1)); |
| } |
| |
| Tensor& take_along_dim_out(const Tensor& self, const Tensor& indices, c10::optional<int64_t> opt_dim, Tensor& result) { |
| checkDevice("torch.take_along_dim():", {self, indices, result}, self.device()); |
| if (opt_dim.has_value()) { |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int64_t dim; |
| Tensor self_broadcasted, indices_broadcasted; |
| std::tie(self_broadcasted, indices_broadcasted, dim) = |
| _take_along_dim_helper(self, indices, opt_dim.value()); |
| return at::gather_out(result, self_broadcasted, dim, indices_broadcasted); |
| } |
| |
| // similar to `take`, but `take` doesn't support the same dtypes as `gather`. |
| return at::gather_out(result, self.view(-1), 0, indices.view(-1)); |
| } |
| |
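| // Builds the gradient of gather as a sparse COO tensor: row `dim` of the |
| // index matrix holds the gather index itself, while every other row |
| // enumerates the corresponding coordinate of each element of `grad` in |
| // row-major order; the values are `grad` flattened. This is the sparse |
| // analogue of scatter-adding `grad` into zeros_like(self). |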
| Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){ |
| // special case scalar input and/or index |
| if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes()); |
| if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(index.view({1,1}), grad, self.sizes()); |
| Tensor sparse_ind = at::empty({self.ndimension(), grad.numel()}, self.options().dtype(at::kLong)); |
| int64_t grad_numel = grad.numel(); |
| if (grad_numel > 0) { |
| int64_t n_above = grad_numel; |
| int64_t n_below = 1; |
| if (dim < 0) dim += self.ndimension(); |
| for (const auto i : c10::irange(self.ndimension())) { |
| n_above /= grad.size(i); |
| if (i == dim) { |
| sparse_ind[i] = index.reshape(-1); |
| } else { |
| sparse_ind[i] = at::arange(grad.size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand({grad.size(i), n_above}).reshape(-1).repeat(n_below); |
| } |
| n_below *= grad.size(i); |
| } |
| } |
| return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes()); |
| } |
| |
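| // Counts the nonzero elements visited by `iter` within `range`, using four |
| // independent accumulators (manual 4-way unrolling) so the counting loop is |
| // not serialized on a single increment dependency; the partial counts are |
| // summed at the end. |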
| template <typename scalar_t> |
| int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) { |
| int64_t num_nonzero = 0; |
| |
| auto loop = [&](char** data, const int64_t* strides, int64_t n) { |
| constexpr int ilp_factor = 4; |
| const char* ptr = data[0]; |
| const auto stride = strides[0]; |
| int64_t nonzero[ilp_factor] = {0}; |
| |
| int64_t i = 0; |
| for (; i + (ilp_factor - 1) < n; i += ilp_factor) { |
| c10::ForcedUnroll<ilp_factor>{}([&](int k) { |
| const auto& val = c10::load<scalar_t>(ptr + k * stride); |
| if (val != scalar_t(0)) { |
| ++nonzero[k]; |
| } |
| }); |
| ptr += ilp_factor * stride; |
| } |
| for (; i < n; ++i) { |
| const auto& val = c10::load<scalar_t>(ptr); |
| if (val != scalar_t(0)) { |
| ++nonzero[0]; |
| } |
| ptr += stride; |
| } |
| for (const auto k : c10::irange(1, ilp_factor)) { |
| nonzero[0] += nonzero[k]; |
| } |
| num_nonzero += nonzero[0]; |
| }; |
| iter.serial_for_each(loop, range); |
| |
| return num_nonzero; |
| } |
| |
| Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){ |
| return (self != 0).sum(dims); |
| } |
| |
| Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){ |
| if (dims.size() > 0) { |
| return (self != 0).sum(dims); |
| } |
| |
| // Optimized all-reduce |
| auto iter = TensorIteratorConfig() |
| .add_input(self) |
| .build(); |
| |
| const auto num_threads = at::get_num_threads(); |
| DimVector thread_count_nonzero(num_threads); |
| |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( |
| kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] { |
| at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) { |
| const auto tid = at::get_thread_num(); |
| thread_count_nonzero[tid] = count_nonzero_impl<scalar_t>(iter, {begin, end}); |
| }); |
| }); |
| |
| for (const auto i : c10::irange(1, num_threads)) { |
| thread_count_nonzero[0] += thread_count_nonzero[i]; |
| } |
| auto out = at::empty({}, self.options().dtype(kLong)); |
| *out.data_ptr<int64_t>() = thread_count_nonzero[0]; |
| return out; |
| } |
| |
| |
| Tensor count_nonzero(const Tensor& self, c10::optional<int64_t> dim) { |
| if (dim) { |
| return at::count_nonzero(self, IntArrayRef{*dim}); |
| } |
| return at::count_nonzero(self, IntArrayRef{}); |
| } |
| |
| |
| Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) { |
| TORCH_CHECK(result.scalar_type() == kLong, |
| "nonzero: Expected out tensor to have scalar type Long " |
| "but got scalar type", result.scalar_type()); |
| at::assert_no_internal_overlap(result); |
| at::assert_no_overlap(result, self); |
| |
| auto iter = TensorIteratorConfig() |
| .add_input(self) |
| .enforce_linear_iteration() |
| .build(); |
| |
| const auto numel = iter.numel(); |
| const auto num_threads = at::get_num_threads(); |
| DimVector thread_begin(num_threads, -1); |
| DimVector thread_count_nonzero(num_threads + 1); |
| |
| // Pass 1: Count nonzero elements per thread |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( |
| kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] { |
| at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) { |
| const auto tid = at::get_thread_num(); |
| thread_begin[tid] = begin; |
| thread_count_nonzero[tid + 1] = count_nonzero_impl<scalar_t>(iter, {begin, end}); |
| }); |
| }); |
| |
| // Convert thread-local counts to cumulative sum |
| for (const auto i : c10::irange(1, thread_count_nonzero.size())) { |
| thread_count_nonzero[i] += thread_count_nonzero[i - 1]; |
| } |
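| // Illustrative example: per-thread counts [3, 0, 2] become offsets |
| // [0, 3, 3, 5], so thread 0 writes output rows [0, 3), thread 1 writes |
| // nothing, thread 2 writes rows [3, 5), and total_nonzero is 5. |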
| |
| const auto self_sizes = self.sizes(); |
| const auto total_nonzero = thread_count_nonzero.back(); |
| const int64_t ndim = self_sizes.size(); |
| if (resize_output(result, {total_nonzero, ndim})) { |
| // Default to fortran-contiguous output (see gh-46224) |
| result.as_strided_({total_nonzero, ndim}, {1, total_nonzero}); |
| } |
| |
| if (result.numel() == 0) { |
| return result; |
| } |
| |
| // Pass 2: Write indices |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( |
| kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] { |
| at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) { |
| auto tid = at::get_thread_num(); |
| // Work needs to be distributed the same on both passes |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(begin == thread_begin[tid]); |
| |
| // The extra leading element acts as a sentinel, which is faster than an additional bounds check inside the loop |
| c10::SmallVector<int64_t, 33> sizes(ndim + 1, -1); |
| std::copy(self_sizes.begin(), self_sizes.end(), sizes.begin() + 1); |
| c10::SmallVector<int64_t, 33> current_idx(ndim + 1); |
| if (begin > 0) { |
| auto idx = begin; |
| for (int64_t k = ndim; idx > 0 && k > 0; --k) { |
| current_idx[k] = idx % sizes[k]; |
| idx /= sizes[k]; |
| } |
| } |
| |
| auto out_accessor = result.accessor<int64_t, 2>(); |
| auto out_ptr = out_accessor[thread_count_nonzero[tid]].data(); |
| |
| auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) { |
| // Copy into local variables to improve compiler alias analysis |
| int64_t* C10_RESTRICT local_idx = current_idx.data() + 1; |
| const int64_t* C10_RESTRICT local_sizes = sizes.data() + 1; |
| const auto in_stride = strides[0]; |
| const auto out_stride1 = out_accessor.stride(1); |
| const auto out_stride0 = out_accessor.stride(0) - ndim * out_stride1; |
| const auto ndim = out_accessor.size(1); |
| int64_t* out = out_ptr; |
| |
| for (const auto i : c10::irange(n2)) { |
| const char* ptr = data[0] + i * strides[1]; |
| for (const auto j : c10::irange(n1)) { |
| (void)j; // Suppress unused variable warning |
| const auto& val = c10::load<scalar_t>(ptr); |
| // If nonzero, write index |
| if (val != scalar_t(0)) { |
| for (const auto k : c10::irange(ndim)) { |
| *out = local_idx[k]; |
| out += out_stride1; |
| } |
| out += out_stride0; |
| } |
| ptr += in_stride; |
| |
| // Advance current index |
| int64_t k = ndim - 1; |
| ++local_idx[k]; |
| while (C10_UNLIKELY(local_idx[k] == local_sizes[k])) { |
| local_idx[k] = 0; |
| --k; |
| ++local_idx[k]; |
| } |
| } |
| } |
| out_ptr = out; |
| }; |
| iter.serial_for_each(loop, {begin, end}); |
| TORCH_INTERNAL_ASSERT(out_ptr == out_accessor[thread_count_nonzero[tid + 1]].data()); |
| }); |
| }); |
| return result; |
| } |
| |
| Tensor nonzero_cpu(const Tensor& self) { |
| auto result = at::empty({0}, self.options().dtype(kLong)); |
| nonzero_out_cpu(self, result); |
| return result; |
| } |
| |
| std::vector<Tensor> nonzero_numpy(const Tensor& self) { |
| // special case scalar for compatibility with numpy: |
| // |
| // >>> np.array(5).nonzero() |
| // (array([0]),) |
| // >>> np.array(0).nonzero() |
| // (array([], dtype=int64),) |
| |
| if (self.dim() == 0) { |
| return self.unsqueeze(0).nonzero().unbind(1); |
| } |
| |
| return self.nonzero().unbind(1); |
| } |
| |
| Tensor argwhere(const Tensor& self) { |
| return self.nonzero(); |
| } |
| |
| Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & source) { |
| at::assert_no_internal_overlap(self); |
| TORCH_CHECK( |
| self.scalar_type() == source.scalar_type(), |
| "masked_scatter: expected self and source to have same dtypes but got", |
| self.scalar_type(), |
| " and ", |
| source.scalar_type()); |
| |
| TORCH_CHECK(self.device().type() == at::kCPU, "device type of self (", self.device().type(), ") is not CPU"); |
| TORCH_CHECK(mask.device().type() == at::kCPU, "device type of mask (", mask.device().type(), ") is not CPU"); |
| TORCH_CHECK(source.device().type() == at::kCPU, "device type of source (", source.device().type(), ") is not CPU"); |
| |
| c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_"); |
| |
| if (b_mask->dtype() == ScalarType::Byte) { |
| TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \ |
| "please use a mask with dtype torch.bool instead."); |
| } |
| |
| auto src_cont = source.contiguous(); |
| |
| auto iter = TensorIteratorConfig() |
| .set_check_mem_overlap(false) |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| .add_output(self) |
| .add_input(*b_mask) |
| .build(); |
| |
| masked_scatter_stub(iter.device_type(), iter, src_cont); |
| return self; |
| } |
| |
| }} // at::native |