sparse.mm backward: performance improvements (#94991)
Makes the backward of `torch.sparse.mm` faster and free of device syncs in "most" cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94991
Approved by: https://github.com/Skylion007, https://github.com/pearu, https://github.com/cpuhrsch
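
Illustrative usage, not part of the patch: the path this PR speeds up is the backward of sparse-sparse `torch.sparse.mm`, which previously restricted each gradient to its input's sparsity pattern via an unconditional `coalesce()` plus a sparse-sparse multiply with an all-ones mask. A minimal sketch of the affected workload:

```python
import torch

# Both operands are sparse COO, so the gradient goes through
# sparse_sparse_matmul_backward (the function changed in this PR).
a = torch.eye(4).to_sparse().requires_grad_(True)
b = torch.eye(4).to_sparse().requires_grad_(True)

out = torch.sparse.mm(a, b)
out.to_dense().sum().backward()

# Both gradients are sparse COO tensors whose nonzeros sit inside the
# inputs' nonzero structure.
print(a.grad.is_sparse, b.grad.is_sparse)
```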
diff --git a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
index 7888ac6a..67e28d8 100644
--- a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
+++ b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
@@ -36,6 +36,13 @@
}
};
+struct LhsProjOp {
+ template <typename scalar_t>
+ static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) {
+ return a;
+ }
+};
+
template <int nt, int vt, typename loop_t>
C10_LAUNCH_BOUNDS_2(nt, vt)
__global__ void apply_kernel(int n, loop_t loop) {
@@ -70,11 +77,12 @@
TensorIterator& iter,
int64_t lhs_nnz_stride,
int64_t rhs_nnz_stride,
- const Tensor& argsort) {
+ const Tensor& argsort,
+ const bool accumulate_matches) {
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
- sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
+ sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
}
return;
}
@@ -106,7 +114,8 @@
accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
accscalar_t rhs_values;
index_t rhs_sorted_nnz_idx;
- for (int64_t c = 0; c < count; ++c) {
+ const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
+ for (int64_t c = 0; c < match_count; ++c) {
rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
res_values += binary_op_t::apply(lhs_values, rhs_values);
@@ -126,7 +135,8 @@
const Tensor& rhs_values,
const Tensor& rhs_select_idx,
const Tensor& intersection_counts,
- const Tensor& argsort) {
+ const Tensor& argsort,
+ const bool accumulate_matches) {
auto iter = make_value_selection_intersection_iter(
lhs_values,
lhs_select_idx,
@@ -150,7 +160,7 @@
// COO indices are only 64-bit for now.
using index_t = int64_t;
binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
- iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
+ iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
});
return res_values;
@@ -180,9 +190,21 @@
);
}
+void sparse_mask_projection_out_cuda_kernel(
+ Tensor& result,
+ const Tensor& x,
+ const Tensor& y,
+ const OptTensor& x_hash_opt = c10::nullopt) {
+ using CUDAValueLhsProjKernel = CUDAValueSelectionIntersectionKernel<LhsProjOp>;
+ _sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueLhsProjKernel>(
+ result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false
+ );
+}
+
}
REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel);
+REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel);
} // namespace at::native
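
For readers skimming the kernel change: `accumulate_matches` controls whether the inner loop sums the op over every rhs nonzero that shares an index with the lhs entry (needed for `mul` on uncoalesced inputs) or stops after the first match, which is all `LhsProjOp` needs since it ignores the rhs value. A heavily simplified Python model of that loop (assumption: no hashing or sorting, `matches[i]` is precomputed):

```python
# Simplified model of the value-selection loop (not the ATen kernel):
# for output entry i, accumulate op(lhs_values[i], rhs_values[j]) over the
# rhs nonzeros j whose index matches; accumulate_matches=False clamps the
# match list to at most one element, mirroring std::min<int64_t>(count, 1).
def value_selection(lhs_values, rhs_values, matches, op, accumulate_matches=True):
    out = []
    for lhs_v, hit_list in zip(lhs_values, matches):
        hits = hit_list if accumulate_matches else hit_list[:1]
        out.append(sum(op(lhs_v, rhs_values[j]) for j in hits))  # 0 when no match
    return out

mul_op = lambda a, b: a * b   # MulOp: every duplicate match contributes
lhs_proj = lambda a, b: a     # LhsProjOp: the rhs value is irrelevant
```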
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2e04e24..69968cd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6801,6 +6801,12 @@
SparseCsrCPU, SparseCsrCUDA: sparse_mask_sparse_csr
autogen: sparse_mask.out
+- func: _sparse_mask_projection(Tensor self, Tensor mask) -> Tensor
+ variants: method
+ dispatch:
+ SparseCPU, SparseCUDA: sparse_mask_projection
+ autogen: _sparse_mask_projection.out
+
- func: _to_cpu(Tensor[] tensors) -> Tensor[]
variants: function
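
A hedged reading of the new private op (internal API, subject to change), based on the kernels above: `t._sparse_mask_projection(mask)` keeps `t`'s own indices and zeroes out values whose index does not also occur in `mask`, whereas `t.sparse_mask(mask)` lays `t`'s values out on `mask`'s indices. Their densified results agree; they differ only in which operand's pattern the output inherits, which is what lets the backward pick the cheaper search direction:

```python
import torch

t = torch.tensor([[1., 0.],
                  [2., 3.]]).to_sparse()
mask = torch.tensor([[1., 0.],
                     [0., 1.]]).to_sparse()

# Same dense result, different sparsity layout of the output.
print(t.sparse_mask(mask).to_dense())              # t's values on mask's indices
print(t._sparse_mask_projection(mask).to_dense())  # t's indices, masked values
```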
diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
index 0e1b96f..94faadf 100644
--- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
+++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
@@ -136,6 +136,7 @@
const std::vector<int64_t> broadcasted_shape,
const c10::optional<Tensor>& x_hash_opt_ = c10::nullopt,
const c10::optional<Tensor>& y_hash_opt_ = c10::nullopt,
+ const bool accumulate_matches = true,
const bool distributive_with_sum = true
) {
// The common dtype check is relevant when op is done in-place.
@@ -403,7 +404,8 @@
probably_coalesced._values().to(binary_op_res_dtype),
intersection_first_idx.to(nnz_arange.scalar_type()),
intersection_count,
- argsort_hash).to(res.scalar_type());
+ argsort_hash,
+ accumulate_matches).to(res.scalar_type());
const auto res_sparse_dim = source.sparse_dim();
const auto res_dense_dim = source.dense_dim();
const auto& res_shape = broadcasted_shape;
diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
index 211d7f6..9d1f349 100644
--- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
+++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
@@ -35,6 +35,13 @@
}
};
+struct LhsProjOp {
+ template <typename scalar_t>
+ static scalar_t apply(scalar_t a, scalar_t b) {
+ return a;
+ }
+};
+
template <typename binary_op_t>
struct CPUValueSelectionIntersectionKernel {
static Tensor apply(
@@ -43,7 +50,8 @@
const Tensor& rhs_values,
const Tensor& rhs_select_idx,
const Tensor& intersection_counts,
- const Tensor& argsort) {
+ const Tensor& argsort,
+ const bool accumulate_matches) {
auto iter = make_value_selection_intersection_iter(
lhs_values,
lhs_select_idx,
@@ -86,7 +94,8 @@
accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
accscalar_t rhs_values;
index_t rhs_sorted_nnz_idx;
- for (int64_t c = 0; c < count; ++c) {
+ const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
+ for (int64_t c = 0; c < match_count; ++c) {
rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
res_values += binary_op_t::apply(lhs_values, rhs_values);
@@ -132,6 +141,17 @@
);
}
+void sparse_mask_projection_out_cpu_kernel(
+ Tensor& result,
+ const Tensor& x,
+ const Tensor& y,
+ const OptTensor& x_hash_opt = c10::nullopt) {
+ using CPUValueLhsProjKernel = CPUValueSelectionIntersectionKernel<LhsProjOp>;
+ _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueLhsProjKernel>(
+ result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false
+ );
+}
+
}
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
@@ -145,4 +165,10 @@
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
-} // namespace at::native
+
+REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+}
diff --git a/aten/src/ATen/native/sparse/SparseStubs.h b/aten/src/ATen/native/sparse/SparseStubs.h
index 7782043..0f71fa2 100644
--- a/aten/src/ATen/native/sparse/SparseStubs.h
+++ b/aten/src/ATen/native/sparse/SparseStubs.h
@@ -16,6 +16,9 @@
using sparse_mask_intersection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional<Tensor>& x_hash_opt);
DECLARE_DISPATCH(sparse_mask_intersection_out_fn, sparse_mask_intersection_out_stub);
+using sparse_mask_projection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional<Tensor>& x_hash_opt);
+DECLARE_DISPATCH(sparse_mask_projection_out_fn, sparse_mask_projection_out_stub);
+
using flatten_indices_fn = Tensor (*)(const Tensor& indices, IntArrayRef size);
DECLARE_DISPATCH(flatten_indices_fn, flatten_indices_stub);
diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index e446e9b..d54f19f 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -46,6 +46,7 @@
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/empty_native.h>
+#include <ATen/ops/zeros_like.h>
#include <ATen/ops/index_select.h>
#include <ATen/ops/indices_native.h>
#include <ATen/ops/is_coalesced_native.h>
@@ -55,6 +56,7 @@
#include <ATen/ops/sparse_coo_tensor_native.h>
#include <ATen/ops/sparse_dim_native.h>
#include <ATen/ops/sparse_mask_native.h>
+#include <ATen/ops/_sparse_mask_projection_native.h>
#include <ATen/ops/sparse_resize_and_clear_native.h>
#include <ATen/ops/sparse_resize_native.h>
#include <ATen/ops/to_dense_native.h>
@@ -744,6 +746,72 @@
}
DEFINE_DISPATCH(sparse_mask_intersection_out_stub);
+DEFINE_DISPATCH(sparse_mask_projection_out_stub);
+
+using OptTensor = c10::optional<Tensor>;
+
+std::tuple<Tensor, Tensor, OptTensor> sparse_mask_like_prepare_sparse_inputs(
+ const std::string& method_name,
+ const Tensor& t,
+ const Tensor& mask) {
+ // This is a helper function for operations that implement "sparse_mask"-like
+ // functionality, namely, projection of values of one tensor onto the other.
+ // These operations mostly rely on COO intersection primitives that heavily
+ // exploit coalesced inputs to avoid any syncs and calls to sort. The problem
+ // is that these primitives might project the first argument onto the second one
+ // or the other way around, depending on which argument is coalesced and which is
+ // larger. This function prepares inputs for `sparse_mask` such that `t` is
+ // projected onto `mask` by sorting `t` if it is uncoalesced and artificially
+ // marking it as coalesced, all while `mask` is treated as uncoalesced.
+ // The result of this projection is going to be uncoalesced, so it is up to the
+ // caller to set the corresponding flag correctly with respect to the operation's
+ // semantics.
+
+ // We already assume that t.sizes() == mask.sizes()
+ TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(),
+ method_name, "(): the number of sparse dimensions in `self` ",
+ "should match that of the `mask`. ",
+ "Got `self.sparse_dim() == ", t.sparse_dim(), "` != ",
+ "`mask.sparse_dim() == ", mask.sparse_dim(), "`.");
+
+ const auto wrapped_tensor = [](const Tensor& t,
+ const OptTensor& indices = c10::nullopt,
+ const OptTensor& values = c10::nullopt) -> Tensor {
+ auto res = at::empty({0}, t.options());
+ auto* res_sparse_impl = get_sparse_impl(res);
+ res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes());
+ const auto res_indices = indices.has_value() ? *indices : t._indices();
+ const auto res_values = values.has_value() ? *values : t._values();
+ res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
+ res_sparse_impl->set_nnz_and_narrow(t._nnz());
+ res._coalesced_(false);
+ return res;
+ };
+
+ Tensor lhs;
+ OptTensor lhs_hash_opt;
+
+ std::tie(lhs, lhs_hash_opt) = [&]() -> auto {
+ if (t.is_coalesced()) {
+ return std::make_tuple(t, static_cast<OptTensor>(c10::nullopt));
+ } else {
+ const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes());
+ const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0));
+ // Probably worth having a dedicated kernel for.
+ const auto res_indices = t._indices().index_select(1, argsort_indices_hash);
+ const auto res_values = t._values().index_select(0, argsort_indices_hash);
+ const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash);
+ // NOTE: res is not necessarily coalesced, but it is sorted.
+ // We mark it as "coalesced" to skip sorting in the intersection kernel.
+ auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true);
+ return std::make_tuple(res, static_cast<OptTensor>(indices_hash_sorted));
+ }
+ }();
+
+ const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask;
+
+ return std::make_tuple(lhs, rhs, lhs_hash_opt);
+}
SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) {
TORCH_CHECK(
@@ -753,57 +821,25 @@
" but mask has size ",
mask.sizes());
- if (!mask.numel()) {
+ if (t.is_same(mask)) {
+ return t;
+ }
+
+ if (!mask.numel() || !mask._nnz()) {
return mask.clone().to(t.device(), t.scalar_type());
}
if (t.layout() == at::kSparse) {
- TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(),
- "sparse_mask(): the number of sparse dimensions in `self` ",
- "should match that of the `mask`. ",
- "Got `self.sparse_dim() == ", t.sparse_dim(), "` != ",
- "`mask.sparse_dim() == ", mask.sparse_dim(), "`.");
-
- using OptTensor = c10::optional<Tensor>;
-
- const auto wrapped_tensor = [](const Tensor& t,
- const OptTensor& indices = c10::nullopt,
- const OptTensor& values = c10::nullopt) -> Tensor {
- auto res = at::empty({0}, t.options());
- auto* res_sparse_impl = get_sparse_impl(res);
- res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes());
- const auto res_indices = indices.has_value() ? *indices : t._indices();
- const auto res_values = values.has_value() ? *values : t._values();
- res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
- res_sparse_impl->set_nnz_and_narrow(t._nnz());
- res._coalesced_(false);
+ if (!t._nnz()) {
+ auto res = mask.clone().to(t.device(), t.scalar_type());
+ res._values().zero_();
return res;
- };
-
- using OptTensor = c10::optional<Tensor>;
- Tensor lhs;
- OptTensor lhs_hash_opt;
-
- std::tie(lhs, lhs_hash_opt) = [&]() -> auto {
- if (t.is_coalesced()) {
- return std::make_tuple(t, static_cast<OptTensor>(c10::nullopt));
- } else {
- const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes());
- const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0));
- // Probably worth having a dedicated kernel for.
- const auto res_indices = t._indices().index_select(1, argsort_indices_hash);
- const auto res_values = t._values().index_select(0, argsort_indices_hash);
- const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash);
- // NOTE: res is not necessariy coalesced, but it is sorted.
- // We mark it as "coalesced" to skip sorting in the intersection kernel.
- auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true);
- return std::make_tuple(res, static_cast<OptTensor>(indices_hash_sorted));
- }
- }();
-
- const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask;
+ }
auto res = at::empty({0}, t.options());
+ Tensor lhs, rhs;
+ OptTensor lhs_hash_opt;
+ std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("sparse_mask", t, mask);
sparse_mask_intersection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt);
return res._coalesced_(mask.is_coalesced());
}
@@ -816,6 +852,31 @@
return t.mul(mask_template).to(t.scalar_type());
}
+Tensor sparse_mask_projection(const Tensor& t, const Tensor& mask) {
+ TORCH_INTERNAL_ASSERT(t.is_sparse());
+ TORCH_INTERNAL_ASSERT(mask.is_sparse());
+
+ TORCH_CHECK(
+ mask.sizes().equals(t.sizes()),
+ "_sparse_mask_projection(): operands have incompatible sizes; self has size ",
+ t.sizes(),
+ " but mask has size ",
+ mask.sizes());
+
+ if (!t.numel() || !t._nnz() || !mask._nnz()) {
+ auto res = t.clone();
+ res._values().zero_();
+ return res;
+ }
+
+ auto res = at::empty({0}, t.options());
+ Tensor lhs, rhs;
+ OptTensor lhs_hash_opt;
+ std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("_sparse_mask_projection", mask, t);
+ sparse_mask_projection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt);
+ return res._coalesced_(t.is_coalesced());
+}
+
Tensor empty_like_sparse_coo(
const Tensor& self,
c10::optional<ScalarType> dtype,
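
The heart of `sparse_mask_like_prepare_sparse_inputs` is a single argsort over flattened indices that replaces a full `coalesce()`. A rough Python rendering of that step with public ops (assumptions: row-major flattening, the function name is made up; the real code uses `at::sparse::flatten_indices`):

```python
import torch

def sort_by_flat_index(t):
    # Flatten the sparse indices to one integer key per nonzero, argsort once,
    # and reorder indices/values so downstream intersection can treat `t` as
    # sorted without coalescing (duplicates may remain, but they are adjacent).
    indices, values = t._indices(), t._values()
    sizes = torch.tensor(t.shape[: t.sparse_dim()], device=indices.device)
    strides = torch.cumprod(torch.flip(sizes, [0]), 0)
    strides = torch.flip(torch.cat([torch.ones_like(sizes[:1]), strides[:-1]]), [0])
    flat = (indices * strides.unsqueeze(1)).sum(0)
    order = flat.argsort()
    sorted_t = torch.sparse_coo_tensor(
        indices.index_select(1, order), values.index_select(0, order), t.shape)
    return sorted_t, flat.index_select(0, order)
```

The C++ helper additionally marks the sorted tensor as coalesced and hands the sorted hash to the intersection kernel so the kernel can skip its own sort (and the sync it would imply).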
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index f914bab..2e12685 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -465,6 +465,8 @@
aten::_sparse_log_softmax.out
aten::_sparse_log_softmax_backward_data
aten::_sparse_log_softmax_backward_data.out
+aten::_sparse_mask_projection
+aten::_sparse_mask_projection.out
aten::_sparse_mm_reduce_impl
aten::_sparse_mm_reduce_impl_backward
aten::_sparse_softmax
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 1cd06e7..342ab86 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1451,6 +1451,30 @@
mat2.layout());
}
+Tensor sparse_mask_like_grad(const Tensor& x, const Tensor& gx) {
+ if (x.is_coalesced() && gx.is_coalesced()) {
+ if (x._nnz() >= gx._nnz()) {
+ // search into x is faster
+ return gx._sparse_mask_projection(x);
+ } else {
+ // search into gx is faster
+ return gx.sparse_mask(x);
+ }
+ } else if (x.is_coalesced()) {
+ return gx.sparse_mask(x);
+ } else if (gx.is_coalesced()) {
+ return gx._sparse_mask_projection(x);
+ } else {
+ if (x._nnz() >= gx._nnz()) {
+ // gx.coalesce() is likely faster
+ return gx.coalesce()._sparse_mask_projection(x);
+ } else {
+ // x.coalesce() is likely faster
+ return gx.sparse_mask(x.coalesce());
+ }
+ }
+}
+
Tensor sparse_sparse_matmul_backward(
const Tensor& grad,
const Tensor& a,
@@ -1475,19 +1499,13 @@
TORCH_CHECK(
grad_order == 0 || grad_order == 1,
": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");
- const auto mask_ones_like = [](const Tensor& t) -> Tensor {
- return at::sparse_coo_tensor(
- t._indices(),
- at::ones({1}, t._values().options()).expand_as(t._values()),
- t.sizes());
- };
if (grad_order == 0) {
auto a_grad = _sparse_sparse_matmul(grad, b.conj().t());
- return a_grad.mul(mask_ones_like(a.coalesce()));
+ return sparse_mask_like_grad(a, a_grad);
}
auto b_grad = _sparse_sparse_matmul(a.conj().t(), grad);
- return b_grad.mul(mask_ones_like(b.coalesce()));
+ return sparse_mask_like_grad(b, b_grad);
}
Tensor renorm_backward(
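
The autograd change above replaces the old "multiply by an all-ones mask of the coalesced input" trick with a direct mask, choosing between `sparse_mask` and `_sparse_mask_projection` based on which operand is coalesced and smaller. A small sketch of the equivalence it relies on (illustrative; `mask_ones_like` is reconstructed loosely from the removed lines):

```python
import torch

def mask_ones_like(t):
    # Old approach: ones on t's indices, which forces t.coalesce() and an
    # extra sparse-sparse multiply just to restrict gx to t's pattern.
    t = t.coalesce()
    return torch.sparse_coo_tensor(t._indices(), torch.ones_like(t._values()), t.shape)

x = torch.tensor([[0., 1.], [2., 0.]]).to_sparse()
gx = torch.tensor([[5., 6.], [7., 8.]]).to_sparse()

old = gx.mul(mask_ones_like(x))
new = gx.sparse_mask(x)  # one of the branches sparse_mask_like_grad may take
assert torch.equal(old.to_dense(), new.to_dense())
```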
diff --git a/torch/overrides.py b/torch/overrides.py
index 11989e7..c51b35e 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1331,6 +1331,7 @@
Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
Tensor.sparse_dim: lambda self: -1,
Tensor.sparse_mask: lambda self, mask: -1,
+ Tensor._sparse_mask_projection: lambda self, mask: -1,
Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py
index b057634..eb91a49 100644
--- a/torchgen/static_runtime/generator.py
+++ b/torchgen/static_runtime/generator.py
@@ -126,6 +126,7 @@
"zero",
"_sparse_addmm",
"sparse_mask",
+ "_sparse_mask_projection",
"_to_dense",
"_coalesce",
"_coalesced",