#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/AccumulateType.h>

namespace at::native {

namespace {

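// Adapts the generic kernel-launcher interface expected by
// _sparse_binary_op_intersection_kernel_out to the CPU path: it simply
// forwards the elementwise functor to cpu_kernel(), which loops over the
// TensorIterator.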
template <typename func_t>
struct CPUKernelLauncher {
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    cpu_kernel(iter, f);
  }
};

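// Value-combination functors plugged into CPUValueSelectionIntersectionKernel.
// MulOp multiplies matched values (with a bool specialization that uses
// logical AND rather than arithmetic on promoted ints); RhsProjOp and
// LhsProjOp simply select the right- or left-hand value and are used to
// implement the sparse masking/projection kernels below.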
struct MulOp {
  template <typename scalar_t>
  static scalar_t apply(scalar_t a, scalar_t b) {
    return a * b;
  }
};

template <>
bool MulOp::apply(bool a, bool b) {
  return a && b;
}

struct RhsProjOp {
  template <typename scalar_t>
  static scalar_t apply(scalar_t a, scalar_t b) {
    return b;
  }
};

struct LhsProjOp {
  template <typename scalar_t>
  static scalar_t apply(scalar_t a, scalar_t b) {
    return a;
  }
};

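// Given the values of two sparse COO operands, per-result selection indices
// into each operand's nnz dimension, the number of rhs matches per lhs
// element, and an argsort that makes rhs matches contiguous, this kernel
// combines matched value pairs with binary_op_t. When accumulate_matches is
// true all matches are reduced into the result; otherwise only the first
// match is used.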
template <typename binary_op_t>
struct CPUValueSelectionIntersectionKernel {
  static Tensor apply(
      const Tensor& lhs_values,
      const Tensor& lhs_select_idx,
      const Tensor& rhs_values,
      const Tensor& rhs_select_idx,
      const Tensor& intersection_counts,
      const Tensor& argsort,
      const bool accumulate_matches) {
    auto iter = make_value_selection_intersection_iter(
        lhs_values,
        lhs_select_idx,
        rhs_values,
        rhs_select_idx,
        intersection_counts);
    auto res_values = iter.tensor(0);

    auto lhs_nnz_stride = lhs_values.stride(0);
    auto rhs_nnz_stride = rhs_values.stride(0);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, at::ScalarType::ComplexHalf,
        res_values.scalar_type(),
        "binary_op_intersection_cpu", [&] {
          // COO indices are only 64-bit for now.
          using index_t = int64_t;
          auto loop = [&](char** data, const int64_t* strides, int64_t n) {
            auto* ptr_res_values_bytes = data[0];
            const auto* ptr_lhs_values_bytes = data[1];
            const auto* ptr_lhs_select_idx_bytes = data[2];
            const auto* ptr_rhs_values_bytes = data[3];
            const auto* ptr_rhs_select_idx_bytes = data[4];
            const auto* ptr_intersection_counts_bytes = data[5];
            const auto* ptr_argsort = argsort.const_data_ptr<index_t>();

            for (int64_t i = 0; i < n; ++i) {
              // Extract data
              auto* ptr_res_values = reinterpret_cast<scalar_t*>(ptr_res_values_bytes);
              const auto* ptr_lhs_values = reinterpret_cast<const scalar_t*>(ptr_lhs_values_bytes);
              const auto lhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_lhs_select_idx_bytes);
              const auto* ptr_rhs_values = reinterpret_cast<const scalar_t*>(ptr_rhs_values_bytes);
              const auto rhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_rhs_select_idx_bytes);
              const auto count = *reinterpret_cast<const int64_t*>(ptr_intersection_counts_bytes);

              const auto* ptr_lhs_begin = ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride;
              const auto* ptr_rhs_sorted_nnz_idx = ptr_argsort + rhs_nnz_idx;

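              // Accumulate in acc_type (e.g. float for Half/BFloat16) to
              // reduce rounding error when several matches are summed.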
              using accscalar_t = at::acc_type<scalar_t, /*is_gpu=*/false>;
              accscalar_t res_values = 0;
              accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
              accscalar_t rhs_values;
              index_t rhs_sorted_nnz_idx;
              const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
              for (int64_t c = 0; c < match_count; ++c) {
                rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
                rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
                res_values += binary_op_t::apply(lhs_values, rhs_values);
              }
              *ptr_res_values = static_cast<scalar_t>(res_values);

              // Advance
              ptr_res_values_bytes += strides[0];
              ptr_lhs_values_bytes += strides[1];
              ptr_lhs_select_idx_bytes += strides[2];
              ptr_rhs_values_bytes += strides[3];
              ptr_rhs_select_idx_bytes += strides[4];
              ptr_intersection_counts_bytes += strides[5];
            }
          };
          iter.for_each(loop, at::internal::GRAIN_SIZE);
        });

    return res_values;
  }
};

using OptTensor = std::optional<Tensor>;

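// The kernels below are the CPU implementations registered for the sparse
// intersection stubs declared in SparseStubs.h. They are reached through the
// DispatchStub mechanism, e.g. (roughly) mul_sparse_sparse_out_stub(kCPU,
// result, x, y), which ends up here for sparse COO tensors on CPU.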
void mul_sparse_sparse_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y) {
  using CPUValueSelectionMulKernel = CPUValueSelectionIntersectionKernel<MulOp>;
  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueSelectionMulKernel>(
      result, x, y
  );
}

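// Projects y's values onto the sparsity pattern of its intersection with x
// (RhsProjOp); used for sparse_mask-style operations. The optional x_hash
// lets callers pass a precomputed hash of x's indices so it is not
// recomputed here.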
void sparse_mask_intersection_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y,
    const OptTensor& x_hash_opt = c10::nullopt) {
  using CPUValueRhsProjKernel = CPUValueSelectionIntersectionKernel<RhsProjOp>;
  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueRhsProjKernel>(
      result, x, y, x_hash_opt
  );
}

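// Projects x's values onto the sparsity pattern of its intersection with y
// (LhsProjOp). With accumulate_matches=true, an x value that matches several
// nnz entries of y is accumulated once per match; otherwise it contributes
// at most once.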
void sparse_mask_projection_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y,
    const OptTensor& x_hash_opt,
    bool accumulate_matches) {
  using CPUValueLhsProjKernel = CPUValueSelectionIntersectionKernel<LhsProjOp>;
  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueLhsProjKernel>(
      result, x, y, x_hash_opt, c10::nullopt, accumulate_matches
  );
}

} // anonymous namespace

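// Register the same scalar implementation for each listed CPU capability:
// these kernels have no hand-written vectorization, so the DEFAULT, AVX2,
// AVX512, VSX and ZVECTOR builds all share the generic code path above.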
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);

REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);

REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
} // namespace at::native