sparse.mm backward: performance improvements (#94991)

`torch.sparse.mm` - faster and without syncs in "most" cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94991
Approved by: https://github.com/Skylion007, https://github.com/pearu, https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
index 7888ac6a..67e28d8 100644
--- a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
+++ b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
@@ -36,6 +36,13 @@
   }
 };
 
+struct LhsProjOp {
+  template <typename scalar_t>
+  static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) {
+    return a;
+  }
+};
+
 template <int nt, int vt, typename loop_t>
 C10_LAUNCH_BOUNDS_2(nt, vt)
 __global__ void apply_kernel(int n, loop_t loop) {
@@ -70,11 +77,12 @@
     TensorIterator& iter,
     int64_t lhs_nnz_stride,
     int64_t rhs_nnz_stride,
-    const Tensor& argsort) {
+    const Tensor& argsort,
+    const bool accumulate_matches) {
   if (!iter.can_use_32bit_indexing()) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
       binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
-          sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
+          sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
     }
     return;
   }
@@ -106,7 +114,8 @@
     accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
     accscalar_t rhs_values;
     index_t rhs_sorted_nnz_idx;
-    for (int64_t c = 0; c < count; ++c) {
+    const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
+    for (int64_t c = 0; c < match_count; ++c) {
       rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
       rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
       res_values += binary_op_t::apply(lhs_values, rhs_values);
@@ -126,7 +135,8 @@
       const Tensor& rhs_values,
       const Tensor& rhs_select_idx,
       const Tensor& intersection_counts,
-      const Tensor& argsort) {
+      const Tensor& argsort,
+      const bool accumulate_matches) {
     auto iter = make_value_selection_intersection_iter(
         lhs_values,
         lhs_select_idx,
@@ -150,7 +160,7 @@
           // COO indices are only 64-bit for now.
           using index_t = int64_t;
           binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
-              iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
+              iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
         });
 
     return res_values;
@@ -180,9 +190,21 @@
   );
 }
 
+void sparse_mask_projection_out_cuda_kernel(
+    Tensor& result,
+    const Tensor& x,
+    const Tensor& y,
+    const OptTensor& x_hash_opt = c10::nullopt) {
+  using CUDAValueLhsProjKernel = CUDAValueSelectionIntersectionKernel<LhsProjOp>;
+  _sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueLhsProjKernel>(
+      result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false
+  );
+}
+
 }
 
 REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel);
 REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel);
+REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel);
 
 } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2e04e24..69968cd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6801,6 +6801,12 @@
     SparseCsrCPU, SparseCsrCUDA: sparse_mask_sparse_csr
   autogen: sparse_mask.out
 
+- func: _sparse_mask_projection(Tensor self, Tensor mask) -> Tensor
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_mask_projection
+  autogen: _sparse_mask_projection.out
+
 - func: _to_cpu(Tensor[] tensors) -> Tensor[]
   variants: function
 
diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
index 0e1b96f..94faadf 100644
--- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
+++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
@@ -136,6 +136,7 @@
     const std::vector<int64_t> broadcasted_shape,
     const c10::optional<Tensor>& x_hash_opt_ = c10::nullopt,
     const c10::optional<Tensor>& y_hash_opt_ = c10::nullopt,
+    const bool accumulate_matches = true,
     const bool distributive_with_sum = true
 ) {
   // The common dtype check is relevant when op is done in-place.
@@ -403,7 +404,8 @@
       probably_coalesced._values().to(binary_op_res_dtype),
       intersection_first_idx.to(nnz_arange.scalar_type()),
       intersection_count,
-      argsort_hash).to(res.scalar_type());
+      argsort_hash,
+      accumulate_matches).to(res.scalar_type());
   const auto res_sparse_dim = source.sparse_dim();
   const auto res_dense_dim = source.dense_dim();
   const auto& res_shape = broadcasted_shape;
diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
index 211d7f6..9d1f349 100644
--- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
+++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
@@ -35,6 +35,13 @@
   }
 };
 
+struct LhsProjOp {
+  template <typename scalar_t>
+  static scalar_t apply(scalar_t a, scalar_t b) {
+    return a;
+  }
+};
+
 template <typename binary_op_t>
 struct CPUValueSelectionIntersectionKernel {
   static Tensor apply(
@@ -43,7 +50,8 @@
       const Tensor& rhs_values,
       const Tensor& rhs_select_idx,
       const Tensor& intersection_counts,
-      const Tensor& argsort) {
+      const Tensor& argsort,
+      const bool accumulate_matches) {
     auto iter = make_value_selection_intersection_iter(
         lhs_values,
         lhs_select_idx,
@@ -86,7 +94,8 @@
                 accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
                 accscalar_t rhs_values;
                 index_t rhs_sorted_nnz_idx;
-                for (int64_t c = 0; c < count; ++c) {
+                const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
+                for (int64_t c = 0; c < match_count; ++c) {
                   rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
                   rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
                   res_values += binary_op_t::apply(lhs_values, rhs_values);
@@ -132,6 +141,17 @@
   );
 }
 
+void sparse_mask_projection_out_cpu_kernel(
+    Tensor& result,
+    const Tensor& x,
+    const Tensor& y,
+    const OptTensor& x_hash_opt = c10::nullopt) {
+  using CPUValueLhsProjKernel = CPUValueSelectionIntersectionKernel<LhsProjOp>;
+  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueLhsProjKernel>(
+      result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false
+  );
+}
+
 }
 
 REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
@@ -145,4 +165,10 @@
 REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
 REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
 REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
-} // namespace at::native
+
+REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
+}
diff --git a/aten/src/ATen/native/sparse/SparseStubs.h b/aten/src/ATen/native/sparse/SparseStubs.h
index 7782043..0f71fa2 100644
--- a/aten/src/ATen/native/sparse/SparseStubs.h
+++ b/aten/src/ATen/native/sparse/SparseStubs.h
@@ -16,6 +16,9 @@
 using sparse_mask_intersection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional<Tensor>& x_hash_opt);
 DECLARE_DISPATCH(sparse_mask_intersection_out_fn, sparse_mask_intersection_out_stub);
 
+using sparse_mask_projection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional<Tensor>& x_hash_opt);
+DECLARE_DISPATCH(sparse_mask_projection_out_fn, sparse_mask_projection_out_stub);
+
 using flatten_indices_fn = Tensor (*)(const Tensor& indices, IntArrayRef size);
 DECLARE_DISPATCH(flatten_indices_fn, flatten_indices_stub);
 
diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index e446e9b..d54f19f 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -46,6 +46,7 @@
 #include <ATen/ops/empty.h>
 #include <ATen/ops/empty_like_native.h>
 #include <ATen/ops/empty_native.h>
+#include <ATen/ops/zeros_like.h>
 #include <ATen/ops/index_select.h>
 #include <ATen/ops/indices_native.h>
 #include <ATen/ops/is_coalesced_native.h>
@@ -55,6 +56,7 @@
 #include <ATen/ops/sparse_coo_tensor_native.h>
 #include <ATen/ops/sparse_dim_native.h>
 #include <ATen/ops/sparse_mask_native.h>
+#include <ATen/ops/_sparse_mask_projection_native.h>
 #include <ATen/ops/sparse_resize_and_clear_native.h>
 #include <ATen/ops/sparse_resize_native.h>
 #include <ATen/ops/to_dense_native.h>
@@ -744,6 +746,72 @@
 }
 
 DEFINE_DISPATCH(sparse_mask_intersection_out_stub);
+DEFINE_DISPATCH(sparse_mask_projection_out_stub);
+
+using OptTensor = c10::optional<Tensor>;
+
+std::tuple<Tensor, Tensor, OptTensor> sparse_mask_like_prepare_sparse_inputs(
+    const std::string& method_name,
+    const Tensor& t,
+    const Tensor& mask) {
+  // This is a helper function for operations that implement "sparse_mask"-like
+  // functionality, namely, projection of values of one tensor onto the other.
+  // These operations mostly rely on COO intersection primitives that heavily
+  // exploit coalesced inputs to avoid any syncs and calls to sort. The problem
+  // is that these primitives might project first argument onto second one or
+  // the other way around depending on which arguments are coalesced and which are
+  // larger. This function prepares inputs for `sparse_mask` such that `t` is
+  // projected onto `mask` by sorting `t` if uncoalesced and artifically marking it
+  // as coalesced all while `mask` is set to uncoalesced.
+  // The result of this projectionk is going to be uncoalesced, so it is up to the
+  // user to set the corresponding flag correctly with respect to the operations'
+  // semantics.
+
+  // We already assume that t.sizes() == mask.sizes()
+  TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(),
+              method_name, "(): the number of sparse dimensions in `self` ",
+              "should match that of the `mask`. ",
+              "Got `self.sparse_dim() == ", t.sparse_dim(), "` != ",
+              "`mask.sparse_dim() == ", mask.sparse_dim(), "`.");
+
+  const auto wrapped_tensor = [](const Tensor& t,
+                                 const OptTensor& indices = c10::nullopt,
+                                 const OptTensor& values = c10::nullopt) -> Tensor {
+    auto res = at::empty({0}, t.options());
+    auto* res_sparse_impl = get_sparse_impl(res);
+    res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes());
+    const auto res_indices = indices.has_value() ? *indices : t._indices();
+    const auto res_values = values.has_value() ? *values : t._values();
+    res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
+    res_sparse_impl->set_nnz_and_narrow(t._nnz());
+    res._coalesced_(false);
+    return res;
+  };
+
+  Tensor lhs;
+  OptTensor lhs_hash_opt;
+
+  std::tie(lhs, lhs_hash_opt) = [&]() -> auto {
+    if (t.is_coalesced()) {
+      return std::make_tuple(t, static_cast<OptTensor>(c10::nullopt));
+    } else {
+      const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes());
+      const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0));
+      // Probably worth having a dedicated kernel for.
+      const auto res_indices = t._indices().index_select(1, argsort_indices_hash);
+      const auto res_values = t._values().index_select(0, argsort_indices_hash);
+      const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash);
+      // NOTE: res is not necessariy coalesced, but it is sorted.
+      // We mark it as "coalesced" to skip sorting in the intersection kernel.
+      auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true);
+      return std::make_tuple(res, static_cast<OptTensor>(indices_hash_sorted));
+    }
+  }();
+
+  const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask;
+
+  return std::make_tuple(lhs, rhs, lhs_hash_opt);
+}
 
 SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) {
   TORCH_CHECK(
@@ -753,57 +821,25 @@
       " but mask has size ",
       mask.sizes());
 
-  if (!mask.numel()) {
+  if (t.is_same(mask)) {
+    return t;
+  }
+
+  if (!mask.numel() || !mask._nnz()) {
     return mask.clone().to(t.device(), t.scalar_type());
   }
 
   if (t.layout() == at::kSparse) {
-    TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(),
-                "sparse_mask(): the number of sparse dimensions in `self` ",
-                "should match that of the `mask`. ",
-                "Got `self.sparse_dim() == ", t.sparse_dim(), "` != ",
-                "`mask.sparse_dim() == ", mask.sparse_dim(), "`.");
-
-    using OptTensor = c10::optional<Tensor>;
-
-    const auto wrapped_tensor = [](const Tensor& t,
-                                   const OptTensor& indices = c10::nullopt,
-                                   const OptTensor& values = c10::nullopt) -> Tensor {
-      auto res = at::empty({0}, t.options());
-      auto* res_sparse_impl = get_sparse_impl(res);
-      res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes());
-      const auto res_indices = indices.has_value() ? *indices : t._indices();
-      const auto res_values = values.has_value() ? *values : t._values();
-      res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
-      res_sparse_impl->set_nnz_and_narrow(t._nnz());
-      res._coalesced_(false);
+    if (!t._nnz()) {
+      auto res = mask.clone().to(t.device(), t.scalar_type());
+      res._values().zero_();
       return res;
-    };
-
-    using OptTensor = c10::optional<Tensor>;
-    Tensor lhs;
-    OptTensor lhs_hash_opt;
-
-    std::tie(lhs, lhs_hash_opt) = [&]() -> auto {
-      if (t.is_coalesced()) {
-        return std::make_tuple(t, static_cast<OptTensor>(c10::nullopt));
-      } else {
-        const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes());
-        const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0));
-        // Probably worth having a dedicated kernel for.
-        const auto res_indices = t._indices().index_select(1, argsort_indices_hash);
-        const auto res_values = t._values().index_select(0, argsort_indices_hash);
-        const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash);
-        // NOTE: res is not necessariy coalesced, but it is sorted.
-        // We mark it as "coalesced" to skip sorting in the intersection kernel.
-        auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true);
-        return std::make_tuple(res, static_cast<OptTensor>(indices_hash_sorted));
-      }
-    }();
-
-    const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask;
+    }
 
     auto res = at::empty({0}, t.options());
+    Tensor lhs, rhs;
+    OptTensor lhs_hash_opt;
+    std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("sparse_mask", t, mask);
     sparse_mask_intersection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt);
     return res._coalesced_(mask.is_coalesced());
   }
@@ -816,6 +852,31 @@
   return t.mul(mask_template).to(t.scalar_type());
 }
 
+Tensor sparse_mask_projection(const Tensor& t, const Tensor& mask) {
+  TORCH_INTERNAL_ASSERT(t.is_sparse());
+  TORCH_INTERNAL_ASSERT(mask.is_sparse());
+
+  TORCH_CHECK(
+      mask.sizes().equals(t.sizes()),
+      "_sparse_mask_projection(): operands have incompatible sizes; self has size ",
+      t.sizes(),
+      " but mask has size ",
+      mask.sizes());
+
+  if (!t.numel() || !t._nnz() || !mask._nnz()) {
+    auto res = t.clone();
+    res._values().zero_();
+    return res;
+  }
+
+  auto res = at::empty({0}, t.options());
+  Tensor lhs, rhs;
+  OptTensor lhs_hash_opt;
+  std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("_sparse_mask_projection", mask, t);
+  sparse_mask_projection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt);
+  return res._coalesced_(t.is_coalesced());
+}
+
 Tensor empty_like_sparse_coo(
     const Tensor& self,
     c10::optional<ScalarType> dtype,
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index f914bab..2e12685 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -465,6 +465,8 @@
 aten::_sparse_log_softmax.out
 aten::_sparse_log_softmax_backward_data
 aten::_sparse_log_softmax_backward_data.out
+aten::_sparse_mask_projection
+aten::_sparse_mask_projection.out
 aten::_sparse_mm_reduce_impl
 aten::_sparse_mm_reduce_impl_backward
 aten::_sparse_softmax
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 1cd06e7..342ab86 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1451,6 +1451,30 @@
       mat2.layout());
 }
 
+Tensor sparse_mask_like_grad(const Tensor& x, const Tensor& gx) {
+  if (x.is_coalesced() && gx.is_coalesced()) {
+    if (x._nnz() >= gx._nnz()) {
+      // search into x is faster
+      return gx._sparse_mask_projection(x);
+    } else {
+      // search into gx is faster
+      return gx.sparse_mask(x);
+    }
+  } else if (x.is_coalesced()) {
+    return gx.sparse_mask(x);
+  } else if (gx.is_coalesced()) {
+    return gx._sparse_mask_projection(x);
+  } else {
+    if (x._nnz() >= gx._nnz()) {
+      // gx.coalesce() is likely faster
+      return gx.coalesce()._sparse_mask_projection(x);
+    } else {
+      // x.coalesce() is likely faster
+      return gx.sparse_mask(x.coalesce());
+    }
+  }
+}
+
 Tensor sparse_sparse_matmul_backward(
     const Tensor& grad,
     const Tensor& a,
@@ -1475,19 +1499,13 @@
   TORCH_CHECK(
       grad_order == 0 || grad_order == 1,
       ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");
-  const auto mask_ones_like = [](const Tensor& t) -> Tensor {
-    return at::sparse_coo_tensor(
-        t._indices(),
-        at::ones({1}, t._values().options()).expand_as(t._values()),
-        t.sizes());
-  };
 
   if (grad_order == 0) {
     auto a_grad = _sparse_sparse_matmul(grad, b.conj().t());
-    return a_grad.mul(mask_ones_like(a.coalesce()));
+    return sparse_mask_like_grad(a, a_grad);
   }
   auto b_grad = _sparse_sparse_matmul(a.conj().t(), grad);
-  return b_grad.mul(mask_ones_like(b.coalesce()));
+  return sparse_mask_like_grad(b, b_grad);
 }
 
 Tensor renorm_backward(
diff --git a/torch/overrides.py b/torch/overrides.py
index 11989e7..c51b35e 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1331,6 +1331,7 @@
         Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
         Tensor.sparse_dim: lambda self: -1,
         Tensor.sparse_mask: lambda self, mask: -1,
+        Tensor._sparse_mask_projection: lambda self, mask: -1,
         Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
         Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
         Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py
index b057634..eb91a49 100644
--- a/torchgen/static_runtime/generator.py
+++ b/torchgen/static_runtime/generator.py
@@ -126,6 +126,7 @@
         "zero",
         "_sparse_addmm",
         "sparse_mask",
+        "_sparse_mask_projection",
         "_to_dense",
         "_coalesce",
         "_coalesced",