| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/native/NonEmptyUtils.h> |
| #include <ATen/native/DispatchStub.h> |
| #include <ATen/native/TensorIterator.h> |
| #include <ATen/native/TensorAdvancedIndexing.h> |
| #include <ATen/core/Tensor.h> |
| #include <ATen/Config.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/NumericUtils.h> |
| #include <ATen/Parallel.h> |
| #include <ATen/native/cpu/ReduceUtils.h> |
| #include <ATen/cpu/vec/functional.h> |
| #include <ATen/cpu/vec/vec.h> |
| #include <c10/util/irange.h> |
| #ifdef USE_FBGEMM |
| #include <fbgemm/Utils.h> |
| #endif |
| #include <ATen/OpMathType.h> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
| #include <ATen/ops/empty.h> |
| #include <ATen/ops/zeros.h> |
| #endif |
| namespace at::native { |
| |
| namespace { |
| |
| // Implement as functors since lambdas don't get optimized. |
| class ReduceMultiply { |
| public: |
| template <typename scalar_t> |
| constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const { |
| using opmath_t = at::opmath_type<scalar_t>; |
| *self_data *= opmath_t(*src_data); |
| } |
| |
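| // For bool, "multiply" degenerates to a logical AND. |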
| constexpr void operator() (bool * self_data, bool * src_data) const { |
| *self_data = *self_data && *src_data; |
| } |
| }; |
| static ReduceMultiply reduce_multiply; |
| |
| class ReduceAdd { |
| public: |
| template <typename scalar_t> |
| constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const { |
| using opmath_t = at::opmath_type<scalar_t>; |
| *self_data += opmath_t(*src_data); |
| } |
| }; |
| static ReduceAdd reduce_add; |
| |
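| // Note: like ReduceAdd, ReduceMean only accumulates the running sum; the |
| // division by the element count that turns the sum into a mean is expected |
| // to be performed by the caller of the kernel. |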
| class ReduceMean { |
| public: |
| template <typename scalar_t> |
| constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const { |
| using opmath_t = at::opmath_type<scalar_t>; |
| *self_data += opmath_t(*src_data); |
| } |
| }; |
| static ReduceMean reduce_mean; |
| |
| class ReduceMaximum { |
| public: |
| template <typename scalar_t> |
| constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const { |
| using opmath_t = at::opmath_type<scalar_t>; |
| *self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::max(*self_data, opmath_t(*src_data)); |
| } |
| }; |
| static ReduceMaximum reduce_maximum; |
| |
| class ReduceMinimum { |
| public: |
| template <typename scalar_t> |
| constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const { |
| using opmath_t = at::opmath_type<scalar_t>; |
| *self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::min(*self_data, opmath_t(*src_data)); |
| } |
| }; |
| static ReduceMinimum reduce_minimum; |
| |
| class TensorAssign { |
| public: |
| template <typename scalar_t> |
| constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const { |
| using opmath_t = at::opmath_type<scalar_t>; |
| *self_data = opmath_t(*src_data); |
| } |
| }; |
| static TensorAssign tensor_assign; |
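| // A minimal usage sketch (illustrative only): every functor writes into an |
| // opmath-typed accumulator, so a reduced-precision source element is widened |
| // before being reduced, e.g.: |
| // |
| //   float acc = 0.f;          // at::opmath_type<c10::BFloat16> is float |
| //   c10::BFloat16 v(1.5f); |
| //   reduce_add(&acc, &v);     // acc += float(v) |
| //   tensor_assign(&acc, &v);  // acc = float(v) |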
| |
| template <bool is_scatter_like = true> |
| struct _cpu_scatter_gather_dim_loop { |
| template <typename scalar_t, typename func_t> |
| void operator()( |
| at::opmath_type<scalar_t>* self_data, int64_t self_dim_stride, |
| int64_t* index_data, int64_t index_dim_stride, |
| scalar_t* src_data, int64_t src_dim_stride, |
| int64_t dim, int64_t index_dim_size, |
| int64_t index_upper_bound, |
| func_t& f |
| ) { |
| |
| for (const auto i : c10::irange(index_dim_size)) { |
| int64_t idx_dim = index_data[i * index_dim_stride]; |
| // we are not putting idx_dim in the error message because it disables |
| // loop optimization in clang-7 |
| TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound, |
| "index ", index_data[i * index_dim_stride], |
| " is out of bounds for dimension ", dim, |
| " with size ", index_upper_bound |
| ); |
| |
| f( |
| self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, |
| src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride |
| ); |
| } |
| } |
| |
| template <typename scalar_t, typename func_t> |
| void operator()( |
| at::opmath_type<scalar_t>* self_data, int64_t self_dim_stride, |
| int64_t* index_data, int64_t index_dim_stride, |
| Scalar value, |
| int64_t dim, int64_t index_dim_size, |
| int64_t index_upper_bound, |
| func_t& f |
| ) { |
| |
| for (const auto i : c10::irange(index_dim_size)) { |
| int64_t idx_dim = index_data[i * index_dim_stride]; |
| // we are not putting idx_dim in the error message because it disables |
| // loop optimization in clang-7 |
| TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound, |
| "index ", index_data[i * index_dim_stride], |
| " is out of bounds for dimension ", dim, |
| " with size ", index_upper_bound |
| ); |
| auto temp = value.to<scalar_t>(); |
| f( |
| self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp |
| ); |
| } |
| } |
| }; |
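| // Illustrative note: with is_scatter_like == true the index picks the |
| // destination along `dim` (self[idx_dim] <- f(src[i])), while with |
| // is_scatter_like == false (gather) it picks the source |
| // (self[i] <- f(src[idx_dim])), mirroring the two pointer offsets above. |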
| |
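| // For reduced floating point types (BFloat16/Half) the kernels below reduce |
| // into a float (opmath) copy of `self` created here and copy the buffer back |
| // into `self` once the reduction is done. |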
| inline void create_acc_buffer(Tensor& buffer, const Tensor& self, bool need_acc) { |
| if (need_acc) { |
| auto acc_type = at::toOpMathType(self.scalar_type()); |
| buffer = at::empty(self.sizes(), self.options().dtype(acc_type)); |
| buffer.copy_(self); |
| } else { |
| buffer = self; |
| } |
| } |
| |
| template <bool is_scatter_like = true> |
| struct cpu_scatter_gather_base_kernel { |
| template <typename func_t> |
| void operator()(const Tensor& self, int64_t dim, |
| const Tensor& index, const Scalar& value, |
| const std::string& method_name, func_t& kernel_func) { |
| |
| Tensor buffer; |
| bool need_acc = isReducedFloatingType(self.scalar_type()); |
| create_acc_buffer(buffer, self, need_acc); |
| |
| auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); |
| auto index_strides = ensure_nonempty_vec(index.strides().vec()); |
| |
| // `dim` is traversed inside the kernel, |
| // which is why we set index.stride(dim) = 0 and index.size(dim) = 1. |
| // Also, index.size(dim) = 1 makes sure that TensorIterator.DimCounter |
| // has the following form: (i_1,..., i_{dim-1}, 0, i_{dim+1},...,i_n). |
| index_sizes[dim] = 1; |
| index_strides[dim] = 0; |
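| // For example (illustrative): for an index of shape (4, 5, 6) and dim = 1, |
| // the squashed iterator shape is (4, 1, 6); each TensorIterator element then |
| // corresponds to one (i, 0, k) counter value and the dim loop below walks |
| // the 5 entries along dim via index_dim_stride. |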
| |
| auto iter = TensorIteratorConfig() |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| // NOLINTNEXTLINE(bugprone-argument-comment) |
| .declare_static_shape(index.sizes(), /*squash_dim=*/dim) |
| .add_output(buffer) |
| .add_const_input(index) |
| .build(); |
| |
| auto self_dim_stride = ensure_nonempty_stride(buffer, dim); |
| auto self_dim_size = ensure_nonempty_size(buffer, dim); |
| |
| auto index_dim_stride = ensure_nonempty_stride(index, dim); |
| auto index_dim_size = ensure_nonempty_size(index, dim); |
| |
| auto index_upper_bound = self_dim_size; |
| |
| // since the index dimension is squashed, the grain size needs to be adjusted |
| // accordingly to keep equal granularity in parallelism. |
| int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size); |
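| // e.g. (illustrative) with at::internal::GRAIN_SIZE == 32768 and |
| // index_dim_size == 128, the parallel grain becomes 32768 / 128 == 256 |
| // iterator elements, i.e. roughly GRAIN_SIZE scalar updates per chunk. |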
| |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( |
| ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), |
| "scatter_gather_scalar_cpu", [&] { |
| constexpr auto SELF_ITER_STRIDE_IDX = 0; |
| constexpr auto INDEX_ITER_STRIDE_IDX = 1; |
| using opmath_t = at::opmath_type<scalar_t>; |
| _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func; |
| auto loop = [&](char** data, const int64_t* strides, int64_t n) { |
| auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX]; |
| auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX]; |
| // we choose between the TensorIterator-then-dim and the |
| // dim-then-TensorIterator loop order depending on |
| // whether dim is the last dimension |
| if (dim == buffer.dim() - 1) { |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| // dim loop is a separate code block |
| // for better performance |
| loop_func.template operator()<scalar_t, func_t>( |
| (opmath_t*)self_data_bytes, self_dim_stride, |
| (int64_t*)index_data_bytes, index_dim_stride, |
| value, dim, index_dim_size, index_upper_bound, |
| kernel_func); |
| |
| self_data_bytes += strides[SELF_ITER_STRIDE_IDX]; |
| index_data_bytes += strides[INDEX_ITER_STRIDE_IDX]; |
| } |
| } |
| else { |
| for (const auto i : c10::irange(index_dim_size)) { |
| auto* self_data = self_data_bytes; |
| auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| int64_t idx_dim = *(int64_t*)index_data; |
| // we are not putting idx_dim in the error message because it disables |
| // loop optimization in clang-7 |
| TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound, |
| "index ", *(int64_t*)index_data, |
| " is out of bounds for dimension ", dim, |
| " with size ", index_upper_bound); |
| |
| auto temp = value.to<scalar_t>(); |
| kernel_func((opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp); |
| |
| self_data += strides[SELF_ITER_STRIDE_IDX]; |
| index_data += strides[INDEX_ITER_STRIDE_IDX]; |
| } |
| } |
| } |
| }; |
| iter.for_each(loop, grain_size); |
| } |
| ); |
| if (need_acc) { |
| self.copy_(buffer); |
| } |
| } |
| |
| template <typename func_t> |
| void operator()(const Tensor& self, int64_t dim, |
| const Tensor& index, const Tensor& src, |
| const std::string& method_name, func_t& kernel_func) { |
| |
| Tensor buffer; |
| bool need_acc = isReducedFloatingType(self.scalar_type()); |
| create_acc_buffer(buffer, self, need_acc); |
| |
| auto iter = TensorIteratorConfig() |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| // NOLINTNEXTLINE(bugprone-argument-comment) |
| .declare_static_shape(index.sizes(), /*squash_dim=*/dim) |
| .add_output(buffer) |
| .add_const_input(src) |
| .add_const_input(index) |
| .build(); |
| |
| auto self_dim_stride = ensure_nonempty_stride(buffer, dim); |
| auto self_dim_size = ensure_nonempty_size(buffer, dim); |
| |
| auto index_dim_stride = ensure_nonempty_stride(index, dim); |
| auto index_dim_size = ensure_nonempty_size(index, dim); |
| |
| auto src_dim_stride = ensure_nonempty_stride(src, dim); |
| auto src_dim_size = ensure_nonempty_size(src, dim); |
| |
| auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size; |
| |
| int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size); |
| |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( |
| ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1), |
| "scatter_gather_tensor_cpu", [&] { |
| constexpr auto SELF_ITER_STRIDE_IDX = 0; |
| constexpr auto INDEX_ITER_STRIDE_IDX = 2; |
| constexpr auto SRC_ITER_STRIDE_IDX = 1; |
| using opmath_t = at::opmath_type<scalar_t>; |
| _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func; |
| auto loop = [&](char** data, const int64_t* strides, int64_t n) { |
| auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX]; |
| auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX]; |
| auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX]; |
| // we choose between the TensorIterator-then-dim and the |
| // dim-then-TensorIterator loop order depending on |
| // whether dim is the last dimension |
| if (dim == buffer.dim() - 1) { |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| // dim loop is a separate code block |
| // for better performance |
| loop_func.template operator()<scalar_t, func_t>( |
| (opmath_t*)self_data_bytes, self_dim_stride, |
| (int64_t*)index_data_bytes, index_dim_stride, |
| (scalar_t*)src_data_bytes, src_dim_stride, |
| dim, index_dim_size, index_upper_bound, |
| kernel_func |
| ); |
| |
| self_data_bytes += strides[SELF_ITER_STRIDE_IDX]; |
| index_data_bytes += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data_bytes += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| else { |
| for (const auto i : c10::irange(index_dim_size)) { |
| auto* self_data = self_data_bytes; |
| auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); |
| auto* src_data = src_data_bytes; |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| int64_t idx_dim = *(int64_t*)index_data; |
| // we are not putting idx_dim in the error message because it disables |
| // loop optimization in clang-7 |
| TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound, |
| "index ", *(int64_t*)index_data, |
| " is out of bounds for dimension ", dim, |
| " with size ", index_upper_bound); |
| |
| kernel_func( |
| (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, |
| (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride); |
| |
| self_data += strides[SELF_ITER_STRIDE_IDX]; |
| index_data += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| } |
| }; |
| iter.for_each(loop, grain_size); |
| } |
| ); |
| if (need_acc) { |
| self.copy_(buffer); |
| } |
| } |
| |
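| // The overloads below differ from the generic tensor overload above only in |
| // the dispatch macro: mean is not dispatched for Bool, and amax/amin are not |
| // dispatched for complex types. |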
| void operator()(const Tensor& self, int64_t dim, |
| const Tensor& index, const Tensor& src, |
| const std::string& method_name, ReduceMean& kernel_func) { |
| |
| Tensor buffer; |
| bool need_acc = isReducedFloatingType(self.scalar_type()); |
| create_acc_buffer(buffer, self, need_acc); |
| |
| auto iter = TensorIteratorConfig() |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| // NOLINTNEXTLINE(bugprone-argument-comment) |
| .declare_static_shape(index.sizes(), /*squash_dim=*/dim) |
| .add_output(buffer) |
| .add_const_input(src) |
| .add_const_input(index) |
| .build(); |
| |
| auto self_dim_stride = ensure_nonempty_stride(buffer, dim); |
| auto self_dim_size = ensure_nonempty_size(buffer, dim); |
| |
| auto index_dim_stride = ensure_nonempty_stride(index, dim); |
| auto index_dim_size = ensure_nonempty_size(index, dim); |
| |
| auto src_dim_stride = ensure_nonempty_stride(src, dim); |
| auto src_dim_size = ensure_nonempty_size(src, dim); |
| |
| auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size; |
| |
| int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size); |
| |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( |
| ScalarType::Half, ScalarType::BFloat16, iter.dtype(1), |
| "scatter_gather_tensor_cpu_reduce_mean", [&] { |
| constexpr auto SELF_ITER_STRIDE_IDX = 0; |
| constexpr auto INDEX_ITER_STRIDE_IDX = 2; |
| constexpr auto SRC_ITER_STRIDE_IDX = 1; |
| using opmath_t = at::opmath_type<scalar_t>; |
| _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func; |
| auto loop = [&](char** data, const int64_t* strides, int64_t n) { |
| auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX]; |
| auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX]; |
| auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX]; |
| // we choose between the TensorIterator-then-dim and the |
| // dim-then-TensorIterator loop order depending on |
| // whether dim is the last dimension |
| if (dim == buffer.dim() - 1) { |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| // dim loop is a separate code block |
| // for better performance |
| loop_func.template operator()<scalar_t, ReduceMean>( |
| (opmath_t*)self_data_bytes, self_dim_stride, |
| (int64_t*)index_data_bytes, index_dim_stride, |
| (scalar_t*)src_data_bytes, src_dim_stride, |
| dim, index_dim_size, index_upper_bound, |
| kernel_func |
| ); |
| |
| self_data_bytes += strides[SELF_ITER_STRIDE_IDX]; |
| index_data_bytes += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data_bytes += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| else { |
| for (const auto i : c10::irange(index_dim_size)) { |
| auto* self_data = self_data_bytes; |
| auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); |
| auto* src_data = src_data_bytes; |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| int64_t idx_dim = *(int64_t*)index_data; |
| // we are not putting idx_dim in the error message because it disables |
| // loop optimization in clang-7 |
| TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound, |
| "index ", *(int64_t*)index_data, |
| " is out of bounds for dimension ", dim, |
| " with size ", index_upper_bound); |
| |
| kernel_func( |
| (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, |
| (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride); |
| |
| self_data += strides[SELF_ITER_STRIDE_IDX]; |
| index_data += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| } |
| }; |
| iter.for_each(loop, grain_size); |
| } |
| ); |
| if (need_acc) { |
| self.copy_(buffer); |
| } |
| } |
| |
| void operator()(const Tensor& self, int64_t dim, |
| const Tensor& index, const Tensor& src, |
| const std::string& method_name, ReduceMaximum& kernel_func) { |
| Tensor buffer; |
| bool need_acc = isReducedFloatingType(self.scalar_type()); |
| create_acc_buffer(buffer, self, need_acc); |
| |
| auto iter = TensorIteratorConfig() |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| // NOLINTNEXTLINE(bugprone-argument-comment) |
| .declare_static_shape(index.sizes(), /*squash_dim=*/dim) |
| .add_output(buffer) |
| .add_const_input(src) |
| .add_const_input(index) |
| .build(); |
| |
| auto self_dim_stride = ensure_nonempty_stride(buffer, dim); |
| auto self_dim_size = ensure_nonempty_size(buffer, dim); |
| |
| auto index_dim_stride = ensure_nonempty_stride(index, dim); |
| auto index_dim_size = ensure_nonempty_size(index, dim); |
| |
| auto src_dim_stride = ensure_nonempty_stride(src, dim); |
| auto src_dim_size = ensure_nonempty_size(src, dim); |
| |
| auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size; |
| |
| int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size); |
| |
| AT_DISPATCH_ALL_TYPES_AND3( |
| ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1), |
| "scatter_gather_tensor_cpu_reduce_amax", [&] { |
| constexpr auto SELF_ITER_STRIDE_IDX = 0; |
| constexpr auto INDEX_ITER_STRIDE_IDX = 2; |
| constexpr auto SRC_ITER_STRIDE_IDX = 1; |
| using opmath_t = at::opmath_type<scalar_t>; |
| _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func; |
| auto loop = [&](char** data, const int64_t* strides, int64_t n) { |
| auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX]; |
| auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX]; |
| auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX]; |
| // we choose between the TensorIterator-then-dim and the |
| // dim-then-TensorIterator loop order depending on |
| // whether dim is the last dimension |
| if (dim == buffer.dim() - 1) { |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| // dim loop is a separate code block |
| // for better performance |
| loop_func.template operator()<scalar_t, ReduceMaximum>( |
| (opmath_t*)self_data_bytes, self_dim_stride, |
| (int64_t*)index_data_bytes, index_dim_stride, |
| (scalar_t*)src_data_bytes, src_dim_stride, |
| dim, index_dim_size, index_upper_bound, |
| kernel_func |
| ); |
| |
| self_data_bytes += strides[SELF_ITER_STRIDE_IDX]; |
| index_data_bytes += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data_bytes += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| else { |
| for (const auto i : c10::irange(index_dim_size)) { |
| auto* self_data = self_data_bytes; |
| auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); |
| auto* src_data = src_data_bytes; |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| int64_t idx_dim = *(int64_t*)index_data; |
| // we are not putting idx_dim in the error message because it disables |
| // loop optimization in clang-7 |
| TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound, |
| "index ", *(int64_t*)index_data, |
| " is out of bounds for dimension ", dim, |
| " with size ", index_upper_bound); |
| |
| kernel_func( |
| (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, |
| (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride); |
| |
| self_data += strides[SELF_ITER_STRIDE_IDX]; |
| index_data += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| } |
| }; |
| iter.for_each(loop, grain_size); |
| } |
| ); |
| if (need_acc) { |
| self.copy_(buffer); |
| } |
| } |
| |
| void operator()(const Tensor& self, int64_t dim, |
| const Tensor& index, const Tensor& src, |
| const std::string& method_name, ReduceMinimum& kernel_func) { |
| |
| Tensor buffer; |
| bool need_acc = isReducedFloatingType(self.scalar_type()); |
| create_acc_buffer(buffer, self, need_acc); |
| |
| auto iter = TensorIteratorConfig() |
| .check_all_same_dtype(false) |
| .resize_outputs(false) |
| // NOLINTNEXTLINE(bugprone-argument-comment) |
| .declare_static_shape(index.sizes(), /*squash_dim=*/dim) |
| .add_output(buffer) |
| .add_const_input(src) |
| .add_const_input(index) |
| .build(); |
| |
| auto self_dim_stride = ensure_nonempty_stride(buffer, dim); |
| auto self_dim_size = ensure_nonempty_size(buffer, dim); |
| |
| auto index_dim_stride = ensure_nonempty_stride(index, dim); |
| auto index_dim_size = ensure_nonempty_size(index, dim); |
| |
| auto src_dim_stride = ensure_nonempty_stride(src, dim); |
| auto src_dim_size = ensure_nonempty_size(src, dim); |
| |
| auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size; |
| |
| int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size); |
| |
| AT_DISPATCH_ALL_TYPES_AND3( |
| ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1), |
| "scatter_gather_tensor_cpu_reduce_amin", [&] { |
| constexpr auto SELF_ITER_STRIDE_IDX = 0; |
| constexpr auto INDEX_ITER_STRIDE_IDX = 2; |
| constexpr auto SRC_ITER_STRIDE_IDX = 1; |
| using opmath_t = at::opmath_type<scalar_t>; |
| _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func; |
| auto loop = [&](char** data, const int64_t* strides, int64_t n) { |
| auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX]; |
| auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX]; |
| auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX]; |
| // we choose between the TensorIterator-then-dim and the |
| // dim-then-TensorIterator loop order depending on |
| // whether dim is the last dimension |
| if (dim == buffer.dim() - 1) { |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| // dim loop is a separate code block |
| // for better performance |
| loop_func.template operator()<scalar_t, ReduceMinimum>( |
| (opmath_t*)self_data_bytes, self_dim_stride, |
| (int64_t*)index_data_bytes, index_dim_stride, |
| (scalar_t*)src_data_bytes, src_dim_stride, |
| dim, index_dim_size, index_upper_bound, |
| kernel_func |
| ); |
| |
| self_data_bytes += strides[SELF_ITER_STRIDE_IDX]; |
| index_data_bytes += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data_bytes += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| else { |
| for (const auto i : c10::irange(index_dim_size)) { |
| auto* self_data = self_data_bytes; |
| auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); |
| auto* src_data = src_data_bytes; |
| for (const auto nelem C10_UNUSED : c10::irange(n)) { |
| int64_t idx_dim = *(int64_t*)index_data; |
| // we are not putting idx_dim in the error message because it disables |
| // loop optimization in clang-7 |
| TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound, |
| "index ", *(int64_t*)index_data, |
| " is out of bounds for dimension ", dim, |
| " with size ", index_upper_bound); |
| |
| kernel_func( |
| (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, |
| (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride); |
| |
| self_data += strides[SELF_ITER_STRIDE_IDX]; |
| index_data += strides[INDEX_ITER_STRIDE_IDX]; |
| src_data += strides[SRC_ITER_STRIDE_IDX]; |
| } |
| } |
| } |
| }; |
| iter.for_each(loop, grain_size); |
| } |
| ); |
| if (need_acc) { |
| self.copy_(buffer); |
| } |
| } |
| }; |
| |
| #ifndef USE_FBGEMM |
| namespace fbgemm { |
| |
| template <typename K, typename V> |
| std::pair<K*, V*> radix_sort_parallel( |
| K* const inp_key_buf, |
| V* const inp_value_buf, |
| K* const tmp_key_buf, |
| V* const tmp_value_buf, |
| const int64_t elements_count, |
| const int64_t max_value) { |
| TORCH_INTERNAL_ASSERT(false, "radix_sort_parallel: ATen not compiled with FBGEMM support"); |
| return std::make_pair(nullptr, nullptr); |
| } |
| |
| } // namespace fbgemm |
| #endif |
| |
| // Note [scatter reduce optimization] |
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // |
| // 1. motivation: optimize `scatter_reduce` for the classic PyG use case: |
| // `scatter_reduce` is used extensively in 'message passing' when |
| // aggregating info. |
| // |
| // Typically, `self` is a 2D tensor and `index` is a 1D extended/broadcasted |
| // tensor, which means that the aggregation is row-wise and we can vectorize |
| // on the inner dimension. |
| // |
| // 2. implementation: map `scatter_reduce` to `spmm` reduce |
| // in the shape of `[M, nnz]` * `[nnz, K]`, where: |
| // |
| // M: self_dim_size |
| // nnz: index_dim_size |
| // K: index.numel() / index_dim_size; |
| // |
| // step 1: convert the input index to CSR format (use radix_sort to |
| // resolve write address conflicts on the `self` tensor) |
| // |
| // step 2: spmm reduce, parallel on M and vectorize on K |
| // |
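| // a small illustrative example (not part of the kernel): with |
| // self = zeros(3, 2) (M = 3, K = 2), index = [2, 0, 2] broadcast over K |
| // (nnz = 3) and reduce = "sum", row 0 of self accumulates src row 1 and |
| // row 2 accumulates src rows 0 and 2, i.e. the same result as multiplying a |
| // sparse 3 x 3 selection matrix with the dense `src` of shape [3, 2]. |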
| |
| template <typename scalar_t, ReductionType reduce> |
| void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) { |
| const int64_t* index_data = index.const_data_ptr<int64_t>(); |
| scalar_t* self_data = self.data_ptr<scalar_t>(); |
| const scalar_t* src_data = src.const_data_ptr<scalar_t>(); |
| |
| const int64_t M = ensure_nonempty_size(self, 0); |
| const int64_t nnz = ensure_nonempty_size(index, 0); |
| const int64_t K = index.numel() / nnz; |
| |
| const int64_t index_upper_bound = M; |
| |
| auto keys = std::make_unique<int64_t[]>(nnz); |
| auto values = std::make_unique<int64_t[]>(nnz); |
| auto keys_tmp = std::make_unique<int64_t[]>(nnz); |
| auto values_tmp = std::make_unique<int64_t[]>(nnz); |
| at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) { |
| for (const auto i : c10::irange(begin, end)) { |
| int64_t index = index_data[i]; |
| TORCH_CHECK(index >= 0 && index < index_upper_bound, |
| "index ", index, |
| " is out of bounds for dimension ", 0, |
| " with size ", index_upper_bound); |
| keys[i] = index; |
| values[i] = i; |
| } |
| }); |
| |
| int64_t* sorted_col_index_keys = nullptr; |
| int64_t* sorted_col_index_values = nullptr; |
| std::tie(sorted_col_index_keys, sorted_col_index_values) = fbgemm::radix_sort_parallel( |
| keys.get(), |
| values.get(), |
| keys_tmp.get(), |
| values_tmp.get(), |
| nnz, |
| M); |
| |
| int num_threads = at::get_num_threads(); |
| std::vector<int64_t> num_uniq(num_threads, 0); |
| at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) { |
| int tid = at::get_thread_num(); |
| for (const auto i : c10::irange(begin, end)) { |
| if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) { |
| num_uniq[tid]++; |
| } |
| } |
| }); |
| num_uniq[0]++; |
| for (const auto n : c10::irange(1, num_threads)) { |
| num_uniq[n] += num_uniq[n - 1]; |
| } |
| |
| // in case some rows are not written into, num_nonzero_rows will be smaller than M |
| int64_t num_nonzero_rows = num_uniq[num_threads - 1]; |
| auto row_index_tmp = std::make_unique<int64_t[]>(num_nonzero_rows); |
| auto row_index_offset_tmp = std::make_unique<int64_t[]>(num_nonzero_rows + 1); |
| int64_t* row_index = row_index_tmp.get(); |
| int64_t* row_index_offset = row_index_offset_tmp.get(); |
| row_index[0] = sorted_col_index_keys[0]; |
| row_index_offset[0] = 0; |
| row_index_offset[num_nonzero_rows] = nnz; |
| |
| at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) { |
| int tid = at::get_thread_num(); |
| int64_t* t_index = row_index + ((tid == 0) ? 1 : num_uniq[tid - 1]); |
| int64_t* t_index_offset = row_index_offset + ((tid == 0) ? 1 : num_uniq[tid - 1]); |
| for (const auto i : c10::irange(begin, end)) { |
| if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) { |
| *t_index = sorted_col_index_keys[i]; |
| *t_index_offset = i; |
| t_index++; |
| t_index_offset++; |
| } |
| } |
| }); |
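| // Illustrative example: for index_data = [2, 0, 2, 1] (nnz = 4) the sorted |
| // (key, value) pairs are (0,1), (1,3), (2,0), (2,2), which yields |
| // row_index = [0, 1, 2] and row_index_offset = [0, 1, 2, 4]; rows of `self` |
| // that never appear in `index` are simply skipped. |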
| |
| using opmath_t = at::opmath_type<scalar_t>; |
| Tensor buffer; |
| opmath_t* buffer_data = nullptr; |
| static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>; |
| if constexpr (need_acc) { |
| auto acc_type = at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true); |
| buffer = at::zeros({num_threads, K}, self.options().dtype(acc_type)); |
| buffer_data = buffer.data_ptr<opmath_t>(); |
| } |
| |
| // TODO: do blocking on col dimension to reduce WR bandwidth |
| at::parallel_for(0, num_nonzero_rows, 1, [&](int64_t begin, int64_t end) { |
| int tid = at::get_thread_num(); |
| TORCH_CHECK(tid < num_threads, |
| "expect thread id smaller than ", num_threads, ", got thread id ", tid); |
| opmath_t* buffer_ptr = nullptr; |
| |
| for (const auto m : c10::irange(begin, end)) { |
| int64_t row = row_index[m]; |
| int64_t off_start = row_index_offset[m]; |
| int64_t off_end = row_index_offset[m + 1]; |
| scalar_t* self_ptr = self_data + row * K; |
| if constexpr (need_acc) { |
| buffer_ptr = buffer_data + tid * K; |
| } else { |
| buffer_ptr = reinterpret_cast<opmath_t*>(self_ptr); |
| } |
| |
| // step 1: reinit rows in `self` if needed |
| _init<scalar_t, reduce>(self_ptr, buffer_ptr, K, include_self); |
| |
| // step 2: reduce |
| for (const auto n : c10::irange(off_start, off_end)) { |
| int64_t col = sorted_col_index_values[n]; |
| update<scalar_t, reduce>(buffer_ptr, src_data + col * K, K); |
| } |
| if constexpr (need_acc) { |
| vec::convert(buffer_ptr, self_ptr, K); |
| } |
| |
| // step 3: finalize |
| int64_t count = include_self ? 1 : 0; |
| count += off_end - off_start; |
| write<scalar_t, reduce>(self_ptr, count, K); |
| } |
| }); |
| } |
| |
| template <typename scalar_t> |
| void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) { |
| const int64_t* index_data = index.const_data_ptr<int64_t>(); |
| scalar_t* result_data = result.data_ptr<scalar_t>(); |
| const scalar_t* self_data = self.const_data_ptr<scalar_t>(); |
| |
| const int64_t M = ensure_nonempty_size(result, 0); |
| const int64_t N = ensure_nonempty_size(self, 0); |
| const int64_t K = index.numel() / M; |
| |
| const int64_t index_upper_bound = N; |
| |
| using Vec = vec::Vectorized<scalar_t>; |
| int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / K); |
| at::parallel_for(0, M, grain_size, [&](int64_t begin, int64_t end) { |
| for (const auto m : c10::irange(begin, end)) { |
| scalar_t* result_ptr = result_data + m * K; |
| int64_t index = index_data[m]; |
| TORCH_CHECK(index >= 0 && index < index_upper_bound, |
| "index ", index, |
| " is out of bounds for dimension ", 0, |
| " with size ", index_upper_bound); |
| const scalar_t* self_ptr = self_data + index * K; |
| int64_t d = 0; |
| for (; d < K - (K % Vec::size()); d += Vec::size()) { |
| Vec out_vec = Vec::loadu(self_ptr + d); |
| out_vec.store(result_ptr + d); |
| } |
| #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) |
| # pragma unroll |
| #endif |
| for (; d < K; d++) { |
| result_ptr[d] = self_ptr[d]; |
| } |
| } |
| }); |
| } |
| |
| void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) { |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_add_expanded_index", [&] { |
| cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, /*include_self*/true); |
| }); |
| } |
| |
| void scatter_reduce_expanded_index_kernel( |
| const Tensor& self, const Tensor& index, const Tensor& src, |
| const ReductionType& reduction, bool include_self) { |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_reduce_expanded_index", [&] { |
| AT_DISPATCH_REDUCTION_TYPES(reduction, [&]() { |
| cpu_scatter_reduce_expanded_index<scalar_t, reduce>(self, index, src, include_self); |
| }); |
| }); |
| } |
| |
| void gather_expanded_index_kernel(const Tensor& result, const Tensor& self, const Tensor& index) { |
| AT_DISPATCH_FLOATING_TYPES_AND( |
| ScalarType::BFloat16, self.scalar_type(), "gather_expanded_index", [&] { |
| cpu_gather_expanded_index_kernel<scalar_t>(result, index, self); |
| }); |
| } |
| |
| void gather_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) { |
| cpu_scatter_gather_base_kernel</*is_scatter_like=*/false>()( |
| result, dim, index, self, |
| "gather_out_cpu", tensor_assign); |
| } |
| |
| void scatter_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { |
| cpu_scatter_gather_base_kernel<>()( |
| self, dim, index, src, "scatter_cpu_", tensor_assign); |
| } |
| |
| void scatter_fill_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) { |
| cpu_scatter_gather_base_kernel<>()( |
| self, dim, index, value, "scatter_fill_cpu_", tensor_assign); |
| } |
| |
| void scatter_add_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { |
| cpu_scatter_gather_base_kernel<>()( |
| self, dim, index, src, |
| "scatter_add_", reduce_add); |
| } |
| |
| void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index, |
| const Tensor& src, const ReductionType& reduce) { |
| switch (reduce) { |
| case ReductionType::SUM : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, src, |
| "scatter_reduce_add_", reduce_add); |
| break; |
| case ReductionType::PROD : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, src, |
| "scatter_reduce_multiply_", reduce_multiply); |
| break; |
| default : |
| break; |
| } |
| } |
| |
| void scatter_reduce_two_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index, |
| const Tensor& src, const ReductionType& reduce) { |
| switch (reduce) { |
| case ReductionType::SUM : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, src, |
| "scatter_reduce_sum_", reduce_add); |
| break; |
| case ReductionType::PROD : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, src, |
| "scatter_reduce_prod_", reduce_multiply); |
| break; |
| case ReductionType::MAX : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, src, |
| "scatter_reduce_amax_", reduce_maximum); |
| break; |
| case ReductionType::MIN : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, src, |
| "scatter_reduce_amin_", reduce_minimum); |
| break; |
| case ReductionType::MEAN : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, src, |
| "scatter_reduce_mean_", reduce_mean); |
| break; |
| } |
| } |
| |
| void scatter_scalar_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index, |
| const Scalar& value, const ReductionType& reduce) { |
| switch (reduce) { |
| case ReductionType::SUM : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, value, |
| "scatter_scalar_reduce_add_", reduce_add); |
| break; |
| case ReductionType::PROD : |
| cpu_scatter_gather_base_kernel<>()(self, dim, index, value, |
| "scatter_scalar_reduce_multiply_", reduce_multiply); |
| break; |
| default: |
| break; |
| } |
| } |
| |
| } // anonymous namespace |
| |
| REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel); |
| REGISTER_DISPATCH(scatter_stub, &scatter_cpu_kernel); |
| REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cpu_kernel); |
| REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel); |
| REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cpu_kernel); |
| REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cpu_kernel); |
| REGISTER_DISPATCH(scatter_reduce_two_stub, &scatter_reduce_two_cpu_kernel); |
| |
| // fast paths for GNN usage |
| REGISTER_DISPATCH(scatter_add_expanded_index_stub, &scatter_add_expanded_index_kernel); |
| REGISTER_DISPATCH(scatter_reduce_expanded_index_stub, &scatter_reduce_expanded_index_kernel); |
| REGISTER_DISPATCH(gather_expanded_index_stub, &gather_expanded_index_kernel); |
| |
| } // namespace at::native |