[quant][pyper] Add support for pruned weights in embedding_bag_byte lookup (#47329)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47329
Adds support for pruned weights in the embedding_bag_byte lookup, along with a mapping tensor for the compressed indices. This brings the byte op in line with the existing 4-bit op: the schema gains a compressed_indices_mapping argument, and when pruned_weights is set the lookup dispatches to FBGEMM's row-wise sparse kernel (GenerateEmbeddingSpMDMRowWiseSparse).
Test Plan:
python test/test_quantization.py TestQuantizedEmbeddingOps
Imported from OSS
Reviewed By: qizzzh
Differential Revision: D24719909
fbshipit-source-id: f998f4039e84bbe1886e492a3bff6aa5f56b6b0f
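
Example usage (a minimal sketch, not taken from the PR; it assumes that -1 in compressed_indices_mapping marks a pruned row, mirroring the mapping-table convention in the updated test):

```python
import torch

num_embeddings, embedding_dim = 10, 16
weights = torch.randn(num_embeddings, embedding_dim)
q_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)

# Prune rows 1 and 3 and build the original-row -> compressed-row mapping.
kept = [i for i in range(num_embeddings) if i not in (1, 3)]
mapping = torch.full((num_embeddings,), -1, dtype=torch.int32)
mapping[kept] = torch.arange(len(kept), dtype=torch.int32)
q_weights = q_weights[kept]  # drop the pruned rows from the packed weights

indices = torch.tensor([0, 2, 4, 5])
offsets = torch.tensor([0, 2, 4])  # include_last_offset=True -> 2 bags

out = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
    q_weights, indices, offsets,
    mode=0,
    pruned_weights=True,
    per_sample_weights=None,
    compressed_indices_mapping=mapping,
    include_last_offset=True)
```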
diff --git a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h
index 3327e7d..cf98bc5 100644
--- a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h
+++ b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h
@@ -7,14 +7,15 @@
virtual at::Tensor embeddingbag_byte(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets,
- bool sparse,
+ bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
+ const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) = 0;
virtual at::Tensor embeddingbag_4bit(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets,
- bool sparse,
+ bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) = 0;
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
index 0cccf81..765e93b 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -342,14 +342,15 @@
at::Tensor embeddingbag_byte(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets,
- bool sparse,
+ bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
+ const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) override;
at::Tensor embeddingbag_4bit(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets,
- bool sparse,
+ bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) override;
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
index 9f30e55..8aa16fb 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -21,7 +21,6 @@
const c10::optional<at::Tensor>& per_sample_weights_,
const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) {
-
TORCH_CHECK(weight.dim() == 2);
TORCH_CHECK(indices.dim() == 1);
TORCH_CHECK(offsets.dim() == 1);
@@ -198,6 +197,7 @@
const at::Tensor& offsets,
bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
+ const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) {
TORCH_CHECK(weight.scalar_type() == at::kByte);
TORCH_CHECK(weight.dim() == 2);
@@ -208,6 +208,15 @@
const auto indices_data = indices.data_ptr<IndexType>();
auto offsets_data = offsets.data_ptr<OffsetType>();
+ // Get compressed indices for pruned_weights.
+ int32_t* compressed_indices_mapping_data = nullptr;
+ int compressed_index_size = 0;
+ if (pruned_weights) {
+ compressed_index_size = compressed_indices_mapping.value().numel();
+ compressed_indices_mapping_data =
+ compressed_indices_mapping.value().data_ptr<int32_t>();
+ }
+
const int64_t N = weight.size(0);
const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias
const int64_t M = offsets.size(0);
@@ -233,33 +242,66 @@
auto output = at::empty(shape, weight.options().dtype(at::kFloat));
auto* output_data = output.data_ptr<float>();
+ const int index_size = indices.numel();
#ifdef USE_FBGEMM
- auto kernel_i8 =
- fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType>(
- /*block_size=*/D,
- /*has_weight=*/per_sample_weights_.has_value(),
- /*normalize_by_lengths=*/false,
- /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- /*is_weight_positional=*/false,
- /*use_offsets=*/true);
+ if (!pruned_weights) {
+ auto kernel_i8 =
+ fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType>(
+ /*block_size=*/D,
+ /*has_weight=*/per_sample_weights_.has_value(),
+ /*normalize_by_lengths=*/false,
+ /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+ /*is_weight_positional=*/false,
+ /*use_offsets=*/true);
- at::parallel_for(0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
- bool success = kernel_i8(
- /*output_size=*/end_idx - start_idx,
- /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
- /*data_size=*/N,
+ at::parallel_for(
+ 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+ bool success = kernel_i8(
+ /*output_size=*/end_idx - start_idx,
+ /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
+ /*data_size=*/N,
+ /*input=*/weight_data,
+ /*indices=*/indices_data + offsets_data[start_idx],
+ /*offsets_or_lengths=*/offsets_data + start_idx,
+ /*weights=*/
+ per_sample_weights_
+ ? per_sample_weights_.value().data_ptr<float>() +
+ offsets_data[start_idx]
+ : nullptr,
+ /*out=*/output_data + start_idx * D);
+
+ TORCH_CHECK(
+ success,
+ "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
+ });
+ } else {
+ // pruned weights
+ auto kernel_i8_sparse = fbgemm::
+ GenerateEmbeddingSpMDMRowWiseSparse<uint8_t, IndexType, OffsetType>(
+ /*block_size=*/D,
+ /*has_weight=*/per_sample_weights_.has_value(),
+ /*normalize_by_lengths=*/false,
+ /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+ /*is_weight_positional=*/false,
+ /*use_offsets=*/true);
+
+ auto success = kernel_i8_sparse(
+ /*output_size=*/output_size,
+ /*index_size=*/index_size,
+ /*data_size=*/compressed_index_size,
/*input=*/weight_data,
- /*indices=*/indices_data + offsets_data[start_idx],
- /*offsets_or_lengths=*/offsets_data + start_idx,
+ /*indices=*/indices_data,
+ /*offsets=*/offsets_data,
/*weights=*/
- per_sample_weights_ ? per_sample_weights_.value().data_ptr<float>() +
- offsets_data[start_idx]
- : nullptr,
- /*out=*/output_data + start_idx * D);
-
+ per_sample_weights_.has_value()
+ ? per_sample_weights_.value().data_ptr<float>()
+ : nullptr,
+ /*output=*/output_data,
+ /*compressed_indices_table=*/compressed_indices_mapping_data);
TORCH_CHECK(
- success, "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
- });
+ success,
+ "FBGEMM GenerateEmbeddingSpMDMRowWiseSparse kernel failed for 8-bit input");
+ }
#endif
// TODO add default (non-FBGEMM) implementation.
return output;
@@ -271,6 +313,7 @@
const c10::optional<at::Tensor>& offsets_in,
bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
+ const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) {
TORCH_CHECK(
offsets_in.has_value(),
@@ -297,6 +340,7 @@
offsets,
pruned_weights,
per_sample_weights_,
+ compressed_indices_mapping,
include_last_offset);
} else if (
indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) {
@@ -306,6 +350,7 @@
offsets,
pruned_weights,
per_sample_weights_,
+ compressed_indices_mapping,
include_last_offset);
} else if (
indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) {
@@ -315,6 +360,7 @@
offsets,
pruned_weights,
per_sample_weights_,
+ compressed_indices_mapping,
include_last_offset);
}
@@ -325,6 +371,7 @@
offsets,
pruned_weights,
per_sample_weights_,
+ compressed_indices_mapping,
include_last_offset);
}
@@ -400,6 +447,7 @@
const c10::optional<at::Tensor>& offsets_in,
bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
+ const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) {
return embedding_bag_byte_helper(
packed_w.contiguous(),
@@ -407,6 +455,7 @@
offsets_in,
pruned_weights,
per_sample_weights_,
+ compressed_indices_mapping,
include_last_offset);
}
@@ -439,6 +488,7 @@
const int64_t /* mode */,
bool pruned_weights,
const c10::optional<Tensor>& per_sample_weights_,
+ const c10::optional<Tensor>& compressed_indices_mapping,
bool include_last_offset) {
return embedding_bag_byte_helper(
weight.contiguous(),
@@ -446,6 +496,7 @@
offsets_in,
pruned_weights,
per_sample_weights_,
+ compressed_indices_mapping,
include_last_offset);
}
@@ -488,6 +539,7 @@
offsets,
pruned_weights,
per_sample_weights_,
+ compressed_indices_mapping,
include_last_offset);
} else if (bit_rate == 4) {
return packed_weight->embeddingbag_4bit(
@@ -516,7 +568,7 @@
at::Tensor output;
if (bit_rate == 8) {
return packed_weight->embeddingbag_byte(
- indices, offsets, pruned_weights, c10::nullopt, false);
+ indices, offsets, pruned_weights, c10::nullopt, c10::nullopt, false);
} else {
TORCH_INTERNAL_ASSERT(
"Currently only support 8-bit embedding quantization");
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index c09501d..91a275a 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -126,7 +126,7 @@
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"));
- m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor"));
+ m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
diff --git a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py
index 4aba09e..5281c43 100644
--- a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py
+++ b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py
@@ -176,13 +176,19 @@
low=0.01, high=0.5, size=[len(self.indices)]).astype(np.float32)) if \
self.enable_per_sample_weights else None
+ self.compressed_indices = None
+
+ if self.is_pruned_weights:
+ self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights)
+
self.op_func = op_func
def forward(self):
return self.op_func(self.prepacked_weights, self.indices, self.offsets,
mode=0, per_sample_weights=self.per_sample_weights,
include_last_offset=self.include_last_offset,
- pruned_weights=self.is_pruned_weights)
+ pruned_weights=self.is_pruned_weights,
+ compressed_indices_mapping=self.compressed_indices)
op_bench.generate_pt_tests_from_op_list(four_bit_rowwise_ops,
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 9f1f6cc..4d22c27 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -116,7 +116,7 @@
("quantized::embedding_bag_byte", datetime.date(2020, 10, 15)),
("quantized::embedding_bag_4bit", datetime.date(2020, 10, 15)),
("quantized::embedding_byte", datetime.date(2020, 10, 15)),
- ("quantized::embedding_bag_byte_rowwise_offsets", datetime.date(2020, 10, 15)),
+ ("quantized::embedding_bag_byte_rowwise_offsets", datetime.date(2020, 11, 15)),
("quantized::embedding_bag_4bit_rowwise_offsets", datetime.date(2020, 10, 15)),
("aten::_foreach_sub_scalar_list", datetime.date(2020, 11, 10)),
("aten::_foreach_add_scalar_list_", datetime.date(2020, 11, 10)),
diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index 61dec79..fcc4e68 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -2969,7 +2969,7 @@
embedding_dim, num_offsets,
use_32bit_indices, use_32bit_offsets,
enable_per_sample_weights,
- include_last_offset, prune_weights, sparsity, atol, rtol):
+ include_last_offset, sparsity, atol, rtol):
pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack
if bit_rate == 4:
@@ -3026,7 +3026,8 @@
mapping_table = np.zeros(num_embeddings, dtype=np.int32)
pruned_weights = weights
- if prune_weights and bit_rate == 4:
+ prune_weights = sparsity > 0
+ if prune_weights:
# Prune and generate mapping table
num_compressed_rows = 0
unpruned_ids = []
@@ -3041,22 +3042,15 @@
unpruned_ids.append(i)
q_weights = q_weights[unpruned_ids]
pruned_weights = weights[unpruned_ids]
- result = pt_op(q_weights,
- indices.int() if use_32bit_indices else indices,
- offsets.int() if use_32bit_offsets else offsets,
- mode=0,
- pruned_weights=prune_weights,
- per_sample_weights=per_sample_weights,
- compressed_indices_mapping=torch.tensor(mapping_table),
- include_last_offset=include_last_offset)
- else:
- result = pt_op(q_weights,
- indices.int() if use_32bit_indices else indices,
- offsets.int() if use_32bit_offsets else offsets,
- mode=0,
- pruned_weights=prune_weights,
- per_sample_weights=per_sample_weights,
- include_last_offset=include_last_offset)
+
+ result = pt_op(q_weights,
+ indices.int() if use_32bit_indices else indices,
+ offsets.int() if use_32bit_offsets else offsets,
+ mode=0,
+ pruned_weights=prune_weights,
+ per_sample_weights=per_sample_weights,
+ compressed_indices_mapping=torch.tensor(mapping_table),
+ include_last_offset=include_last_offset)
reference_result = get_reference_result(
num_embeddings, embedding_dim, include_last_offset, weights,
@@ -3098,18 +3092,20 @@
use_32bit_indices=st.booleans(),
use_32bit_offsets=st.booleans(),
enable_per_sample_weights=st.booleans(),
- include_last_offset=st.booleans())
+ include_last_offset=st.booleans(),
+ sparsity=st.sampled_from([0.0, 0.5, 0.7]))
def test_embedding_bag_byte(self, num_embeddings,
embedding_dim, num_offsets,
use_32bit_indices,
use_32bit_offsets,
enable_per_sample_weights,
- include_last_offset):
+ include_last_offset,
+ sparsity):
self.embedding_bag_rowwise_offsets_run(
8, num_embeddings, embedding_dim, num_offsets,
use_32bit_indices, use_32bit_offsets,
- enable_per_sample_weights, include_last_offset, prune_weights=False,
- sparsity=0, atol=0.005, rtol=1e-3)
+ enable_per_sample_weights, include_last_offset,
+ sparsity=sparsity, atol=0.005, rtol=1e-3)
""" Tests the correctness of the embedding_bag_4bit quantized operator """
@given(num_embeddings=st.integers(10, 100),
@@ -3130,7 +3126,7 @@
embedding_dim, num_offsets,
use_32bit_indices, use_32bit_offsets,
enable_per_sample_weights,
- include_last_offset, True, sparsity=sparsity,
+ include_last_offset, sparsity=sparsity,
atol=0.1, rtol=1e-2)
""" Tests the correctness of the quantized embedding lookup operator """
diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
index efad31f..aaaaf61 100644
--- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -360,10 +360,7 @@
embedding_bag_inputs[9]); // per_sample_weights
}
- if (op_name == "embedding_bag_4bit") {
- // 4-bit op has an extra input compressed_indices_mapping
- qembedding_bag_inputs.push_back(none);
- }
+ qembedding_bag_inputs.push_back(none); // compressed_indices_mapping
qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 1]);
Node* qembedding_bag =