[quant][pyper] Add support for pruned weights in embedding_bag_byte lookup (#47329)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47329

Adds support for pruned weights in the byte (8-bit) embedding_bag lookup, along with a mapping tensor for the compressed indices.
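
The compressed indices mapping translates original embedding row ids into rows of the pruned (compressed) weight tensor. A minimal sketch of how such a table can be built, assuming the same convention as the existing 4-bit path and the updated test (pruned rows map to -1; the pruning pattern here is illustrative):

```python
import torch

num_embeddings = 6
pruned_rows = {1, 4}  # illustrative pruning pattern

# mapping[i] is the row of the compressed weight holding original row i,
# or -1 if row i was pruned away.
mapping = torch.full((num_embeddings,), -1, dtype=torch.int32)
next_row = 0
for i in range(num_embeddings):
    if i not in pruned_rows:
        mapping[i] = next_row
        next_row += 1

print(mapping)  # tensor([ 0, -1,  1,  2, -1,  3], dtype=torch.int32)
```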

Test Plan:
python test/test_quantization.py TestQuantizedEmbeddingOps
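
For reference, a minimal end-to-end sketch of the new call path, assuming a build with FBGEMM; shapes and values are illustrative and the mapping literal follows the -1-for-pruned convention above:

```python
import torch

num_embeddings, embedding_dim = 6, 4
weights = torch.randn(num_embeddings, embedding_dim, dtype=torch.float32)

# Rows 1 and 4 are pruned; only the surviving rows are byte-prepacked.
unpruned_rows = [0, 2, 3, 5]
mapping = torch.tensor([0, -1, 1, 2, -1, 3], dtype=torch.int32)
q_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights[unpruned_rows])

# Indices refer to the *original* row ids; the op remaps them internally.
indices = torch.tensor([0, 2, 3, 5, 0], dtype=torch.int64)
offsets = torch.tensor([0, 2, 5], dtype=torch.int64)  # two bags

out = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
    q_weights,
    indices,
    offsets,
    mode=0,
    pruned_weights=True,
    per_sample_weights=None,
    compressed_indices_mapping=mapping,
    include_last_offset=True,
)
print(out.shape)  # expected torch.Size([2, 4])
```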

Imported from OSS

Reviewed By: qizzzh

Differential Revision: D24719909

fbshipit-source-id: f998f4039e84bbe1886e492a3bff6aa5f56b6b0f
diff --git a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h
index 3327e7d..cf98bc5 100644
--- a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h
+++ b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h
@@ -7,14 +7,15 @@
   virtual at::Tensor embeddingbag_byte(
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& offsets,
-    bool sparse,
+    bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
+    const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) = 0;
 
   virtual at::Tensor embeddingbag_4bit(
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& offsets,
-    bool sparse,
+    bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) = 0;
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
index 0cccf81..765e93b 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -342,14 +342,15 @@
   at::Tensor embeddingbag_byte(
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& offsets,
-    bool sparse,
+    bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
+    const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) override;
 
   at::Tensor embeddingbag_4bit(
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& offsets,
-    bool sparse,
+    bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) override;
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
index 9f30e55..8aa16fb 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -21,7 +21,6 @@
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
-
   TORCH_CHECK(weight.dim() == 2);
   TORCH_CHECK(indices.dim() == 1);
   TORCH_CHECK(offsets.dim() == 1);
@@ -198,6 +197,7 @@
     const at::Tensor& offsets,
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
+    const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
   TORCH_CHECK(weight.scalar_type() == at::kByte);
   TORCH_CHECK(weight.dim() == 2);
@@ -208,6 +208,15 @@
   const auto indices_data = indices.data_ptr<IndexType>();
   auto offsets_data = offsets.data_ptr<OffsetType>();
 
+  // Get compressed indices for pruned_weights.
+  int32_t* compressed_indices_mapping_data = nullptr;
+  int compressed_index_size = 0;
+  if (pruned_weights) {
+    compressed_index_size = compressed_indices_mapping.value().numel();
+    compressed_indices_mapping_data =
+        compressed_indices_mapping.value().data_ptr<int32_t>();
+  }
+
   const int64_t N = weight.size(0);
   const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias
   const int64_t M = offsets.size(0);
@@ -233,33 +242,66 @@
   auto output = at::empty(shape, weight.options().dtype(at::kFloat));
   auto* output_data = output.data_ptr<float>();
 
+  const int index_size = indices.numel();
 #ifdef USE_FBGEMM
-  auto kernel_i8 =
-      fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType>(
-          /*block_size=*/D,
-          /*has_weight=*/per_sample_weights_.has_value(),
-          /*normalize_by_lengths=*/false,
-          /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-          /*is_weight_positional=*/false,
-          /*use_offsets=*/true);
+  if (!pruned_weights) {
+    auto kernel_i8 =
+        fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType>(
+            /*block_size=*/D,
+            /*has_weight=*/per_sample_weights_.has_value(),
+            /*normalize_by_lengths=*/false,
+            /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+            /*is_weight_positional=*/false,
+            /*use_offsets=*/true);
 
-  at::parallel_for(0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
-    bool success = kernel_i8(
-        /*output_size=*/end_idx - start_idx,
-        /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
-        /*data_size=*/N,
+    at::parallel_for(
+        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+          bool success = kernel_i8(
+              /*output_size=*/end_idx - start_idx,
+              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
+              /*data_size=*/N,
+              /*input=*/weight_data,
+              /*indices=*/indices_data + offsets_data[start_idx],
+              /*offsets_or_lengths=*/offsets_data + start_idx,
+              /*weights=*/
+              per_sample_weights_
+                  ? per_sample_weights_.value().data_ptr<float>() +
+                      offsets_data[start_idx]
+                  : nullptr,
+              /*out=*/output_data + start_idx * D);
+
+          TORCH_CHECK(
+              success,
+              "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
+        });
+  } else {
+    // pruned weights
+    auto kernel_i8_sparse = fbgemm::
+        GenerateEmbeddingSpMDMRowWiseSparse<uint8_t, IndexType, OffsetType>(
+            /*block_size=*/D,
+            /*has_weight=*/per_sample_weights_.has_value(),
+            /*normalize_by_lengths=*/false,
+            /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+            /*is_weight_positional=*/false,
+            /*use_offsets=*/true);
+
+    auto success = kernel_i8_sparse(
+        /*output_size=*/output_size,
+        /*index_size=*/index_size,
+        /*data_size=*/compressed_index_size,
         /*input=*/weight_data,
-        /*indices=*/indices_data + offsets_data[start_idx],
-        /*offsets_or_lengths=*/offsets_data + start_idx,
+        /*indices=*/indices_data,
+        /*offsets=*/offsets_data,
         /*weights=*/
-        per_sample_weights_ ? per_sample_weights_.value().data_ptr<float>() +
-                offsets_data[start_idx]
-                            : nullptr,
-        /*out=*/output_data + start_idx * D);
-
+        per_sample_weights_.has_value()
+            ? per_sample_weights_.value().data_ptr<float>()
+            : nullptr,
+        /*output=*/output_data,
+        /*compressed_indices_table=*/compressed_indices_mapping_data);
     TORCH_CHECK(
-        success, "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
-  });
+        success,
+        "FBGEMM GenerateEmbeddingSpMDMRowWiseSparse kernel failed for 8-bit input");
+  }
 #endif
   // TODO add default (non-FBGEMM) implementation.
   return output;
@@ -271,6 +313,7 @@
     const c10::optional<at::Tensor>& offsets_in,
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
+    const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
   TORCH_CHECK(
       offsets_in.has_value(),
@@ -297,6 +340,7 @@
         offsets,
         pruned_weights,
         per_sample_weights_,
+        compressed_indices_mapping,
         include_last_offset);
   } else if (
       indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) {
@@ -306,6 +350,7 @@
         offsets,
         pruned_weights,
         per_sample_weights_,
+        compressed_indices_mapping,
         include_last_offset);
   } else if (
       indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) {
@@ -315,6 +360,7 @@
         offsets,
         pruned_weights,
         per_sample_weights_,
+        compressed_indices_mapping,
         include_last_offset);
   }
 
@@ -325,6 +371,7 @@
       offsets,
       pruned_weights,
       per_sample_weights_,
+      compressed_indices_mapping,
       include_last_offset);
 }
 
@@ -400,6 +447,7 @@
     const c10::optional<at::Tensor>& offsets_in,
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
+    const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
   return embedding_bag_byte_helper(
       packed_w.contiguous(),
@@ -407,6 +455,7 @@
       offsets_in,
       pruned_weights,
       per_sample_weights_,
+      compressed_indices_mapping,
       include_last_offset);
 }
 
@@ -439,6 +488,7 @@
     const int64_t /* mode */,
     bool pruned_weights,
     const c10::optional<Tensor>& per_sample_weights_,
+    const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
   return embedding_bag_byte_helper(
       weight.contiguous(),
@@ -446,6 +496,7 @@
       offsets_in,
       pruned_weights,
       per_sample_weights_,
+      compressed_indices_mapping,
       include_last_offset);
 }
 
@@ -488,6 +539,7 @@
           offsets,
           pruned_weights,
           per_sample_weights_,
+          compressed_indices_mapping,
           include_last_offset);
     } else if (bit_rate == 4) {
       return packed_weight->embeddingbag_4bit(
@@ -516,7 +568,7 @@
     at::Tensor output;
     if (bit_rate == 8) {
       return packed_weight->embeddingbag_byte(
-          indices, offsets, pruned_weights, c10::nullopt, false);
+          indices, offsets, pruned_weights, c10::nullopt, c10::nullopt, false);
     } else {
       TORCH_INTERNAL_ASSERT(
           "Currently only support 8-bit embedding quantization");
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index c09501d..91a275a 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -126,7 +126,7 @@
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
diff --git a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py
index 4aba09e..5281c43 100644
--- a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py
+++ b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py
@@ -176,13 +176,19 @@
             low=0.01, high=0.5, size=[len(self.indices)]).astype(np.float32)) if \
             self.enable_per_sample_weights else None
 
+        self.compressed_indices = None
+
+        if self.is_pruned_weights:
+            self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights)
+
         self.op_func = op_func
 
     def forward(self):
         return self.op_func(self.prepacked_weights, self.indices, self.offsets,
                             mode=0, per_sample_weights=self.per_sample_weights,
                             include_last_offset=self.include_last_offset,
-                            pruned_weights=self.is_pruned_weights)
+                            pruned_weights=self.is_pruned_weights,
+                            compressed_indices_mapping=self.compressed_indices)
 
 
 op_bench.generate_pt_tests_from_op_list(four_bit_rowwise_ops,
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 9f1f6cc..4d22c27 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -116,7 +116,7 @@
     ("quantized::embedding_bag_byte", datetime.date(2020, 10, 15)),
     ("quantized::embedding_bag_4bit", datetime.date(2020, 10, 15)),
     ("quantized::embedding_byte", datetime.date(2020, 10, 15)),
-    ("quantized::embedding_bag_byte_rowwise_offsets", datetime.date(2020, 10, 15)),
+    ("quantized::embedding_bag_byte_rowwise_offsets", datetime.date(2020, 11, 15)),
     ("quantized::embedding_bag_4bit_rowwise_offsets", datetime.date(2020, 10, 15)),
     ("aten::_foreach_sub_scalar_list", datetime.date(2020, 11, 10)),
     ("aten::_foreach_add_scalar_list_", datetime.date(2020, 11, 10)),
diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index 61dec79..fcc4e68 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -2969,7 +2969,7 @@
             embedding_dim, num_offsets,
             use_32bit_indices, use_32bit_offsets,
             enable_per_sample_weights,
-            include_last_offset, prune_weights, sparsity, atol, rtol):
+            include_last_offset, sparsity, atol, rtol):
         pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
         pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack
         if bit_rate == 4:
@@ -3026,7 +3026,8 @@
 
         mapping_table = np.zeros(num_embeddings, dtype=np.int32)
         pruned_weights = weights
-        if prune_weights and bit_rate == 4:
+        prune_weights = sparsity > 0
+        if prune_weights:
             # Prune and generate mapping table
             num_compressed_rows = 0
             unpruned_ids = []
@@ -3041,22 +3042,15 @@
                     unpruned_ids.append(i)
             q_weights = q_weights[unpruned_ids]
             pruned_weights = weights[unpruned_ids]
-            result = pt_op(q_weights,
-                           indices.int() if use_32bit_indices else indices,
-                           offsets.int() if use_32bit_offsets else offsets,
-                           mode=0,
-                           pruned_weights=prune_weights,
-                           per_sample_weights=per_sample_weights,
-                           compressed_indices_mapping=torch.tensor(mapping_table),
-                           include_last_offset=include_last_offset)
-        else:
-            result = pt_op(q_weights,
-                           indices.int() if use_32bit_indices else indices,
-                           offsets.int() if use_32bit_offsets else offsets,
-                           mode=0,
-                           pruned_weights=prune_weights,
-                           per_sample_weights=per_sample_weights,
-                           include_last_offset=include_last_offset)
+
+        result = pt_op(q_weights,
+                       indices.int() if use_32bit_indices else indices,
+                       offsets.int() if use_32bit_offsets else offsets,
+                       mode=0,
+                       pruned_weights=prune_weights,
+                       per_sample_weights=per_sample_weights,
+                       compressed_indices_mapping=torch.tensor(mapping_table),
+                       include_last_offset=include_last_offset)
 
         reference_result = get_reference_result(
             num_embeddings, embedding_dim, include_last_offset, weights,
@@ -3098,18 +3092,20 @@
            use_32bit_indices=st.booleans(),
            use_32bit_offsets=st.booleans(),
            enable_per_sample_weights=st.booleans(),
-           include_last_offset=st.booleans())
+           include_last_offset=st.booleans(),
+           sparsity=st.sampled_from([0.0, 0.5, 0.7]))
     def test_embedding_bag_byte(self, num_embeddings,
                                 embedding_dim, num_offsets,
                                 use_32bit_indices,
                                 use_32bit_offsets,
                                 enable_per_sample_weights,
-                                include_last_offset):
+                                include_last_offset,
+                                sparsity):
         self.embedding_bag_rowwise_offsets_run(
             8, num_embeddings, embedding_dim, num_offsets,
             use_32bit_indices, use_32bit_offsets,
-            enable_per_sample_weights, include_last_offset, prune_weights=False,
-            sparsity=0, atol=0.005, rtol=1e-3)
+            enable_per_sample_weights, include_last_offset,
+            sparsity=sparsity, atol=0.005, rtol=1e-3)
 
     """ Tests the correctness of the embedding_bag_4bit quantized operator """
     @given(num_embeddings=st.integers(10, 100),
@@ -3130,7 +3126,7 @@
                                                embedding_dim, num_offsets,
                                                use_32bit_indices, use_32bit_offsets,
                                                enable_per_sample_weights,
-                                               include_last_offset, True, sparsity=sparsity,
+                                               include_last_offset, sparsity=sparsity,
                                                atol=0.1, rtol=1e-2)
 
     """ Tests the correctness of the quantized embedding lookup operator """
diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
index efad31f..aaaaf61 100644
--- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -360,10 +360,7 @@
         embedding_bag_inputs[9]); // per_sample_weights
   }
 
-  if (op_name == "embedding_bag_4bit") {
-    // 4-bit op has an extra input compressed_indices_mapping
-    qembedding_bag_inputs.push_back(none);
-  }
+  qembedding_bag_inputs.push_back(none); // compressed_indices_mapping
   qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 1]);
 
   Node* qembedding_bag =